From 80059eaa3b97670715eec294b9342b4ff8abe4ac Mon Sep 17 00:00:00 2001 From: jefyang1 <146495389+jefyang1@users.noreply.github.com> Date: Tue, 4 Mar 2025 10:32:25 -0800 Subject: [PATCH] Remove CK_USE_AMD_MFMA_GFX950 (#1935) * Add runtime check in example_gemm_xdl_streamk for gfx950 * Add runtime check in grouped conv fwd examples for gfx950 * Disable CK_USE_AMD_MFMA_GFX950 * Add new instances for gfx950 * Fix test_gemm_universal on gfx950 [ROCm/composable_kernel commit: c95bda93bacd4b8f809becd7c4b9e8b743d0ded8] --- CMakeLists.txt | 3 - example/01_gemm/gemm_xdl_streamk.cpp | 18 +- example/01_gemm/run_gemm_example.inc | 78 +---- example/01_gemm/run_gemm_example_streamk.inc | 270 ++++++++++++++++++ ...grouped_conv_fwd_bias_relu_add_example.inc | 100 ++++--- ...iple_d_welford_first_half_xdl_cshuffle.hpp | 5 +- ...wise_batched_gemm_gemm_xdl_cshuffle_v1.hpp | 5 +- ...iple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp | 11 +- ...ultiple_d_softmax_gemm_xdl_cshuffle_v1.hpp | 11 +- ...ched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp | 5 +- ...e_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp | 5 +- ...ridwise_gemm_multiple_abd_xdl_cshuffle.hpp | 20 +- ...emm_multiple_d_multiple_r_xdl_cshuffle.hpp | 5 +- .../gridwise_gemm_multiple_d_xdl_cshuffle.hpp | 20 +- ...ultiple_d_xdl_cshuffle_lds_direct_load.hpp | 20 +- ...se_gemm_multiple_d_xdl_splitk_cshuffle.hpp | 5 +- .../gridwise_gemm_reduce_xdl_cshuffle_v1.hpp | 5 +- ...e_gemm_split_k_multiple_d_xdl_cshuffle.hpp | 10 +- ...emm_split_k_multiple_d_xdl_cshuffle_v2.hpp | 5 +- ...idwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp | 5 +- .../gridwise_gemm_xdl_cshuffle_streamk_v3.hpp | 5 +- .../grid/gridwise_gemm_xdl_cshuffle_v1.hpp | 12 +- .../grid/gridwise_gemm_xdl_cshuffle_v2.hpp | 5 +- .../grid/gridwise_gemm_xdl_cshuffle_v3.hpp | 5 +- .../gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp | 5 +- ...ridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp | 14 +- .../gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp | 15 +- ..._gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp | 15 +- ...ridwise_gemm_xdl_layernorm_cshuffle_v1.hpp | 5 +- ...ridwise_gemm_xdl_waveletmodel_cshuffle.hpp | 5 +- .../grid/gridwise_gemm_xdlops_bwd_weight.hpp | 14 +- .../gpu/grid/gridwise_gemm_xdlops_v3r1.hpp | 5 +- .../tensor_operation/gpu/warp/xdlops_gemm.hpp | 76 ++--- ...conv_bwd_weight_two_stage_xdl_instance.hpp | 32 +-- ...ice_grouped_conv_fwd_xdl_comp_instance.hpp | 117 ++++++-- ...ed_conv_fwd_xdl_merged_groups_instance.hpp | 56 +++- ...m_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp | 23 +- ...m_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp | 23 +- ...m_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp | 40 ++- ...m_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp | 40 ++- ...6_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp | 8 +- ...f16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp | 4 +- ...6_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp | 4 +- ...f16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp | 4 +- ...6_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp | 46 ++- ..._c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp | 89 ++++-- ..._bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp | 95 ++++-- ...s_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp | 96 +++++-- ..._2_stage_f16_f16_f16_mk_nk_mn_instance.cpp | 23 +- ...uffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp | 32 ++- ..._shuffle_f16_f16_f16_km_kn_mn_instance.cpp | 23 +- ..._shuffle_f16_f16_f16_km_nk_mn_instance.cpp | 23 +- ..._shuffle_f16_f16_f16_mk_kn_mn_instance.cpp | 25 +- ..._shuffle_f16_f16_f16_mk_nk_mn_instance.cpp | 28 +- ...l_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp | 23 +- ...m_kn_mn_interwave_pipeline_v1_instance.cpp | 9 +- ...regular_interwave_pipeline_v1_instance.cpp | 2 - ...m_nk_mn_interwave_pipeline_v1_instance.cpp | 9 +- ...regular_interwave_pipeline_v1_instance.cpp | 2 - ...k_kn_mn_interwave_pipeline_v1_instance.cpp | 9 +- ...regular_interwave_pipeline_v1_instance.cpp | 2 - ...k_nk_mn_interwave_pipeline_v1_instance.cpp | 2 - ...regular_interwave_pipeline_v1_instance.cpp | 2 - ...16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp | 74 +++-- ...16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp | 74 +++-- ...16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp | 74 +++-- ...16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp | 74 +++-- ...e_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp | 65 ++++- ...e_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp | 65 ++++- ...e_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp | 65 ++++- ...e_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp | 65 ++++- ..._layernorm_f16_km_kn_mn_mn_mn_instance.cpp | 8 +- ..._layernorm_f16_km_nk_mn_mn_mn_instance.cpp | 8 +- ..._layernorm_f16_mk_kn_mn_mn_mn_instance.cpp | 8 +- ..._layernorm_f16_mk_nk_mn_mn_mn_instance.cpp | 8 +- ..._shuffle_f16_f16_f16_km_kn_mn_instance.cpp | 71 +++-- ..._shuffle_f16_f16_f16_km_nk_mn_instance.cpp | 71 +++-- ..._shuffle_f16_f16_f16_mk_kn_mn_instance.cpp | 63 +++- ..._shuffle_f16_f16_f16_mk_nk_mn_instance.cpp | 63 +++- ...f16_f16_mk_kn_mn_v1_interwave_instance.cpp | 4 +- ...f16_f16_mk_kn_mn_v1_irregular_instance.cpp | 4 +- ...f16_f16_mk_nk_mn_v1_interwave_instance.cpp | 4 +- ..._xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp | 8 +- ..._xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp | 32 ++- ...16_bf16_km_nk_mn_comp_default_instance.cpp | 8 + ...6_bf16_km_nk_mn_comp_kpadding_instance.cpp | 8 + ..._bf16_km_nk_mn_comp_mkpadding_instance.cpp | 9 + ...6_bf16_km_nk_mn_comp_mpadding_instance.cpp | 8 + ..._bf16_km_nk_mn_mem_v1_default_instance.cpp | 9 + ...bf16_km_nk_mn_mem_v1_kpadding_instance.cpp | 9 + ...f16_km_nk_mn_mem_v1_mkpadding_instance.cpp | 9 + ..._bf16_km_nk_mn_mem_v2_default_instance.cpp | 9 + ...bf16_km_nk_mn_mem_v2_kpadding_instance.cpp | 9 + ...f16_km_nk_mn_mem_v2_mkpadding_instance.cpp | 9 + ..._xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp | 35 +-- ...16_bf16_mk_kn_mn_comp_default_instance.cpp | 8 + ...6_bf16_mk_kn_mn_comp_kpadding_instance.cpp | 8 + ...bf16_mk_kn_mn_comp_mnkpadding_instance.cpp | 9 + ..._bf16_mk_kn_mn_comp_mnpadding_instance.cpp | 9 + ..._bf16_mk_kn_mn_mem_v1_default_instance.cpp | 9 + ...bf16_mk_kn_mn_mem_v1_kpadding_instance.cpp | 9 + ...16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp | 10 + ..._bf16_mk_kn_mn_mem_v2_default_instance.cpp | 9 + ...bf16_mk_kn_mn_mem_v2_kpadding_instance.cpp | 9 + ...16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp | 10 + ..._xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp | 46 +-- ...16_bf16_mk_nk_mn_comp_default_instance.cpp | 8 + ...6_bf16_mk_nk_mn_comp_kpadding_instance.cpp | 8 + ..._bf16_mk_nk_mn_mem_v1_default_instance.cpp | 9 + ...bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp | 9 + ..._bf16_mk_nk_mn_mem_v2_default_instance.cpp | 9 + ...bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp | 9 + ...emm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp | 18 +- ...f16_f16_mk_kn_mn_comp_default_instance.cpp | 8 + ...16_f16_mk_kn_mn_comp_kpadding_instance.cpp | 8 + ..._f16_mk_kn_mn_comp_mnkpadding_instance.cpp | 8 + ...6_f16_mk_kn_mn_comp_mnpadding_instance.cpp | 8 + ...emm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp | 32 ++- ...f16_f16_mk_nk_mn_comp_default_instance.cpp | 8 + ...16_f16_mk_nk_mn_comp_kpadding_instance.cpp | 8 + ...6_f16_mk_nk_mn_mem_v1_default_instance.cpp | 9 + ..._f16_mk_nk_mn_mem_v1_kpadding_instance.cpp | 9 + ...6_f16_mk_nk_mn_mem_v2_default_instance.cpp | 9 + ..._f16_mk_nk_mn_mem_v2_kpadding_instance.cpp | 9 + ...gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp | 18 +- ..._f8_f16_mk_kn_mn_comp_default_instance.cpp | 8 + ...f8_f16_mk_kn_mn_comp_kpadding_instance.cpp | 8 + ..._f16_mk_kn_mn_comp_mnkpadding_instance.cpp | 8 + ...8_f16_mk_kn_mn_comp_mnpadding_instance.cpp | 8 + ...gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp | 18 +- ..._f8_f16_mk_nk_mn_comp_default_instance.cpp | 8 + ...f8_f16_mk_nk_mn_comp_kpadding_instance.cpp | 8 + ...gemm_xdl_universal_f8_f8_bf16_mk_kn_mn.hpp | 22 +- ...f8_bf16_mk_kn_mn_comp_default_instance.cpp | 8 + ...8_bf16_mk_kn_mn_comp_kpadding_instance.cpp | 8 + ..._bf16_mk_kn_mn_comp_nkpadding_instance.cpp | 8 + ...gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp | 29 +- ...f8_bf16_mk_nk_mn_comp_default_instance.cpp | 8 + ...8_bf16_mk_nk_mn_comp_kpadding_instance.cpp | 8 + ..._xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp | 12 +- ...16_bf16_mk_nk_mn_comp_default_instance.cpp | 9 + ..._xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp | 20 +- ...16_bf16_mk_kn_mn_comp_default_instance.cpp | 9 + ...6_bf16_mk_kn_mn_comp_kpadding_instance.cpp | 9 + ...bf16_mk_kn_mn_comp_mnkpadding_instance.cpp | 9 + ..._bf16_mk_kn_mn_comp_mnpadding_instance.cpp | 9 + ...mm_xdl_universal_bf16_i8_bf16_mk_kn_mn.hpp | 25 +- ...i8_bf16_mk_kn_mn_comp_default_instance.cpp | 9 + ...8_bf16_mk_kn_mn_comp_kpadding_instance.cpp | 9 + ...bf16_mk_kn_mn_comp_mnkpadding_instance.cpp | 9 + ..._bf16_mk_kn_mn_comp_mnpadding_instance.cpp | 9 + ...emm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp | 20 +- ...f16_f16_mk_kn_mn_comp_default_instance.cpp | 9 + ...16_f16_mk_kn_mn_comp_kpadding_instance.cpp | 9 + ..._f16_mk_kn_mn_comp_mnkpadding_instance.cpp | 9 + ...6_f16_mk_kn_mn_comp_mnpadding_instance.cpp | 9 + ...versal_streamk_bf16_bf16_bf16_mk_kn_mn.hpp | 10 +- ...16_bf16_mk_kn_mn_comp_default_instance.cpp | 9 + ...6_bf16_mk_kn_mn_comp_kpadding_instance.cpp | 9 + ...bf16_mk_kn_mn_comp_mnkpadding_instance.cpp | 9 + ..._bf16_mk_kn_mn_comp_mnpadding_instance.cpp | 9 + ...versal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp | 13 +- ...16_bf16_mk_nk_mn_comp_default_instance.cpp | 9 + ...6_bf16_mk_nk_mn_comp_kpadding_instance.cpp | 9 + ...universal_streamk_f16_f16_f16_mk_kn_mn.hpp | 11 +- ...universal_streamk_f16_f16_f16_mk_nk_mn.hpp | 18 +- ...f16_f16_mk_nk_mn_comp_default_instance.cpp | 9 + ...16_f16_mk_nk_mn_comp_kpadding_instance.cpp | 9 + ..._f16_mk_nk_mn_comp_mnkpadding_instance.cpp | 9 + ...6_f16_mk_nk_mn_comp_mnpadding_instance.cpp | 9 + ...l_ngchw_gkyxc_ngkhw_bf16_comp_instance.cpp | 25 ++ ...dl_ngchw_gkyxc_ngkhw_f16_comp_instance.cpp | 25 ++ ...l_ngchw_gkyxc_ngkhw_int8_comp_instance.cpp | 25 ++ ...l_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp | 79 +++++ ...dl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp | 79 +++++ ...l_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp | 79 +++++ ...groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 56 ++-- ..._groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp | 56 ++-- ...dhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp | 57 ++++ ...ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp | 57 ++++ ...f16_f8_f16_mk_kn_mn_irregular_instance.cpp | 4 +- ...le_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp | 24 +- ...i8_bf16_mk_kn_mn_comp_default_instance.cpp | 12 + ...8_bf16_mk_kn_mn_comp_kpadding_instance.cpp | 12 + ...bf16_mk_kn_mn_comp_mnkpadding_instance.cpp | 12 + ..._bf16_mk_kn_mn_comp_mnpadding_instance.cpp | 12 + 186 files changed, 3272 insertions(+), 883 deletions(-) mode change 100755 => 100644 example/01_gemm/gemm_xdl_streamk.cpp create mode 100644 example/01_gemm/run_gemm_example_streamk.inc mode change 100755 => 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp mode change 100755 => 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instance.cpp mode change 100755 => 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp mode change 100755 => 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instance.cpp mode change 100755 => 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp mode change 100755 => 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instance.cpp mode change 100755 => 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp mode change 100755 => 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index e90f893de0..3be508382a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -202,9 +202,6 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx9 add_definitions(-DCK_USE_GFX94) set(CK_USE_GFX94 "ON") endif() -if (SUPPORTED_GPU_TARGETS MATCHES "gfx95") - add_definitions(-DCK_USE_AMD_MFMA_GFX950) -endif() if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") message("Enabling WMMA instances") add_definitions(-DCK_USE_WMMA) diff --git a/example/01_gemm/gemm_xdl_streamk.cpp b/example/01_gemm/gemm_xdl_streamk.cpp old mode 100755 new mode 100644 index 01542c4775..41665a79b7 --- a/example/01_gemm/gemm_xdl_streamk.cpp +++ b/example/01_gemm/gemm_xdl_streamk.cpp @@ -27,22 +27,24 @@ using DeviceGemmStreamK = ck::tensor_operation::device::DeviceGemmXdlStreamK // ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ######| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) - < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; -#else // defined(CK_USE_AMD_MFMA_GFX950) < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; // < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8>; // < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>; // < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 128, 32, 128, 4, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8>; -#endif // defined(CK_USE_AMD_MFMA_GFX950) +// instance for double rate mfma instruction on gfx950 +using DeviceGemmStreamK2 = ck::tensor_operation::device::DeviceGemmXdlStreamK +// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; - -// // clang-format on // clang-format on -using DeviceGemmInstance = DeviceGemmStreamK; +using DeviceGemmInstance = DeviceGemmStreamK; +using DeviceGemmInstance2 = DeviceGemmStreamK2; using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; @@ -58,6 +60,6 @@ using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; -#include "run_gemm_example.inc" +#include "run_gemm_example_streamk.inc" int main(int argc, char* argv[]) { return !run_gemm_streamk_example(argc, argv); } diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index 4371af6244..c064ed500c 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -3,8 +3,6 @@ #pragma once -#include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp" - template bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) { @@ -124,23 +122,12 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) auto b_element_op = BElementOp{}; auto c_element_op = CElementOp{}; - using BaseStreamK = ck::tensor_operation::device::DeviceGemmStreamK; - // do GEMM auto gemm = DeviceGemmInstance{}; auto invoker = gemm.MakeInvoker(); float ave_time = 0; - if constexpr(std::is_same::value && - !std::is_base_of::value) + if constexpr(std::is_same::value) { auto argument = gemm.MakeArgument( #ifdef BUILD_INT4_EXAMPLE @@ -171,61 +158,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); } - else if constexpr(std::is_same::value && - std::is_base_of::value) - { - auto argument = gemm.MakeArgument( -#ifdef BUILD_INT4_EXAMPLE - static_cast(a_m_k_device_buf.GetDeviceBuffer()), - static_cast(b_k_n_device_buf.GetDeviceBuffer()), - static_cast(c_m_n_device_buf.GetDeviceBuffer()), -#else - static_cast(a_m_k_device_buf.GetDeviceBuffer()), - static_cast(b_k_n_device_buf.GetDeviceBuffer()), - static_cast(c_m_n_device_buf.GetDeviceBuffer()), -#endif - M, - N, - K, - StrideA, - StrideB, - StrideC, - a_element_op, - b_element_op, - c_element_op, - problem_size.NumSKBlocks); - - if(!gemm.IsSupportedArgument(argument)) - { - std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; - - return true; - } - - std::size_t workspace_size = gemm.GetWorkSpaceSize(&argument); - if(workspace_size != 0) - { - workspace.Realloc(workspace_size); - gemm.SetWorkSpacePointer(&argument, workspace.GetDeviceBuffer()); - } - - ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); - -#if 0 - // TODO!!!!! - if(workspace_size != 0){ - float * ws_ptr = reinterpret_cast(malloc(workspace_size)); - size_t ws_dwords = workspace_size / sizeof(float); - workspace.FromDevice(ws_ptr); - - for(size_t i = 0; i < ws_dwords; i++) { - uint32_t rere = reinterpret_cast(ws_ptr)[i]; - printf("%4lu : %f(0x%08x)\n", i, ws_ptr[i], rere); - } - free(ws_ptr); - } -#endif - } else { // When the Problem Type and Problem Size does not fit. @@ -319,11 +251,3 @@ bool run_gemm_example(int argc, char* argv[]) return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config); } - -bool run_gemm_streamk_example(int argc, char* argv[]) -{ - ProblemSizeStreamK problem_size; - ExecutionConfig config; - - return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config); -} diff --git a/example/01_gemm/run_gemm_example_streamk.inc b/example/01_gemm/run_gemm_example_streamk.inc new file mode 100644 index 0000000000..438afcf71a --- /dev/null +++ b/example/01_gemm/run_gemm_example_streamk.inc @@ -0,0 +1,270 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/host_utility/device_prop.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp" + +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ +#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) + static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); +#endif + + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + switch(config.init_method) + { + case 0: + ck::utils::FillConstant{ck::type_convert(1.f)}(a_m_k); + ck::utils::FillConstant{ck::type_convert(1.f)}(b_k_n); + break; + case 1: + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); + break; + case 2: + ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistribution{-1.f, 1.f}(b_k_n); + break; + case 3: + ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); + break; + case 4: + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(b_k_n); + break; + case 5: + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(b_k_n); + break; + default: + ck::utils::FillUniformDistribution{-0.1f, 0.1f}(a_m_k); + ck::utils::FillUniformDistribution{-0.1f, 0.1f}(b_k_n); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_ref_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + +#ifdef BUILD_INT4_EXAMPLE + DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) * + c_m_n_device_result.mDesc.GetElementSpaceSize()); + + const Tensor a_m_k_converted(a_m_k); + const Tensor b_k_n_converted(b_k_n); + + a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data()); +#else + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_ref_buf(sizeof(CDataType) * + c_m_n_device_ref_result.mDesc.GetElementSpaceSize()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); +#endif + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + using BaseStreamK = ck::tensor_operation::device::DeviceGemmStreamK; + + // do GEMM + static_assert(std::is_base_of::value && + std::is_base_of::value); + auto gemm = DeviceGemmInstance{}; + auto gemm2 = DeviceGemmInstance2{}; // instance for double rate mfma instruction + BaseStreamK* op_ptr = (ck::get_device_name() == "gfx950") ? static_cast(&gemm2) + : static_cast(&gemm); + + float ave_time = 0; + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + auto argument_ptr = op_ptr->MakeArgumentPointer( +#ifdef BUILD_INT4_EXAMPLE + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), +#else + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), +#endif + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + problem_size.NumSKBlocks); + + if(!op_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::cerr << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + auto argument = argument_ptr.get(); + std::size_t workspace_size = op_ptr->GetWorkSpaceSize(argument); + if(workspace_size != 0) + { + workspace.Realloc(workspace_size); + op_ptr->SetWorkSpacePointer(argument, workspace.GetDeviceBuffer()); + } + + ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op_ptr->GetTypeString() << std::endl; + + bool pass = true; + + if((config.do_verification == 1) || (config.do_verification == 3)) + { + // CPU verification + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + std::cout << "Running verification on CPU." << std::endl; + ref_invoker.Run(ref_argument); + +#ifdef BUILD_INT4_EXAMPLE + Tensor c_m_n_device_result_converted(c_m_n_host_result.mDesc); + + c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data()); + + c_m_n_device_result = c_m_n_device_result_converted.CopyAsType(); + + return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result); +#else + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); +#endif + } + + if((config.do_verification == 2) || (config.do_verification == 3)) + { + // GPU verification + auto ref_gemm_gpu = ReferenceGemmInstanceGPU{}; + auto ref_invoker_gpu = ref_gemm_gpu.MakeInvoker(); + + auto ref_argument_gpu = ref_gemm_gpu.MakeArgument( + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_ref_buf.GetDeviceBuffer()), + M, + N, + K, + a_element_op, + b_element_op, + c_element_op); + + std::cout << "Running verification on GPU." << std::endl; + ref_invoker_gpu.Run(ref_argument_gpu, StreamConfig{}); + + c_m_n_device_ref_buf.FromDevice(c_m_n_device_ref_result.mData.data()); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_device_ref_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + return pass == true; +} + +bool run_gemm_streamk_example(int argc, char* argv[]) +{ + ProblemSizeStreamK problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config); +} diff --git a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc index ce42a20be7..627e20e245 100644 --- a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc +++ b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc @@ -32,9 +32,9 @@ using BiasLayout = typename LayoutSettingSelector::BiasLayout; template using ResidualLayout = typename LayoutSettingSelector::ResidualLayout; -#if defined(CK_USE_AMD_MFMA_GFX950) +// instance for double rate mfma on gfx950 (vs gfx942) template -using DeviceConvFwdInstance = +using DeviceConvFwdInstance2 = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, InputLayout, @@ -55,14 +55,14 @@ using DeviceConvFwdInstance = 1, // 256, // BlockSize 128, // MPerBlock - 256, // NPerBlock + 64, // NPerBlock 64, // KPerBlock 16, // AK1 16, // BK1 32, // MPerXdl 32, // NPerXdl 2, // MXdlPerWave - 4, // NXdlPerWave + 1, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -81,7 +81,7 @@ using DeviceConvFwdInstance = 1, S<1, 16, 1, 16>, 4>; -#else // defined(CK_USE_AMD_MFMA_GFX950) +// instance for gfx942- template using DeviceConvFwdInstance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< @@ -104,14 +104,14 @@ using DeviceConvFwdInstance = 1, // 256, // BlockSize 128, // MPerBlock - 256, // NPerBlock - 16, // KPerBlock + 128, // NPerBlock + 32, // KPerBlock 4, // AK1 4, // BK1 32, // MPerXdl 32, // NPerXdl 2, // MXdlPerWave - 4, // NXdlPerWave + 2, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -130,7 +130,6 @@ using DeviceConvFwdInstance = 1, S<1, 16, 1, 16>, 4>; -#endif // defined(CK_USE_AMD_MFMA_GFX950) template using HostConvFwdInstance = ck::tensor_operation::host::ReferenceConvFwd{}; - auto invoker = conv.MakeInvoker(); - auto argument = - conv.MakeArgument(in_device_buf.GetDeviceBuffer(), - wei_device_buf.GetDeviceBuffer(), - std::array{bias_device_buf.GetDeviceBuffer(), - residual_device_buf.GetDeviceBuffer()}, - out_device_buf.GetDeviceBuffer(), - a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - std::array, 2>{ - {d0_g_n_k_wos_lengths, d1_g_n_k_wos_lengths}}, - std::array, 2>{ - {d0_g_n_k_wos_strides, d1_g_n_k_wos_strides}}, - e_g_n_k_wos_lengths, - e_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}); + using BaseGroupedConvFwdMultipleABD = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD< + NDimSpatial, + InputLayout, + WeightLayout, + ck::Tuple, ResidualLayout>, + OutputLayout, + InKernelDataType, + WeiKernelDataType, + ck::Tuple, + OutKernelDataType, + InElementOp, + WeiElementOp, + OutElementOp, + InKernelDataType, // AComputeDataType + InKernelDataType>; // BComputeDataType - if(!conv.IsSupportedArgument(argument)) + static_assert( + std::is_base_of>::value && + std::is_base_of>::value); + + auto conv = DeviceConvFwdInstance{}; // instance for gfx942- + auto conv2 = DeviceConvFwdInstance2{}; // instance for double rate mfma instruction + // on gfx950 + BaseGroupedConvFwdMultipleABD* op_ptr = + (ck::get_device_name() == "gfx950") ? static_cast(&conv2) + : static_cast(&conv); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + auto argument_ptr = op_ptr->MakeArgumentPointer( + in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + std::array{bias_device_buf.GetDeviceBuffer(), + residual_device_buf.GetDeviceBuffer()}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, 2>{ + {d0_g_n_k_wos_lengths, d1_g_n_k_wos_lengths}}, + std::array, 2>{ + {d0_g_n_k_wos_strides, d1_g_n_k_wos_strides}}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + if(!op_ptr->IsSupportedArgument(argument_ptr.get())) { throw std::runtime_error( "wrong! device_conv with the specified compilation parameters does " "not support this Conv problem"); } - float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float avg_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, config.time_kernel}); std::size_t flop = conv_param.GetFlops(); std::size_t num_btype = conv_param.GetByte(); @@ -276,7 +302,7 @@ bool run_grouped_conv_fwd_bias_relu_add(const ExecutionConfig& config, float tflops = static_cast(flop) / 1.E9 / avg_time; float gb_per_sec = num_btype / 1.E6 / avg_time; std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << conv.GetTypeString() << std::endl; + << op_ptr->GetTypeString() << std::endl; if(config.do_verification) { diff --git a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp index f4d0989088..d728360c55 100644 --- a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp @@ -517,8 +517,9 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle // sanity check constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); constexpr bool is_single_rate_mfma = - ((is_same::value || is_same::value) && - lcm_AK1_BK1 <= 4) + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8)) ? true : false; constexpr index_t KPack = diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp index 55e254e015..50b4a734fa 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp @@ -450,8 +450,9 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle // sanity check constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); constexpr bool is_single_rate_mfma = - ((is_same::value || is_same::value) && - lcm_AK1_BK1 <= 4) + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8)) ? true : false; constexpr index_t KPack = math::max( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp index fd16927cc1..79a9410898 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp @@ -361,9 +361,11 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle const auto M = d0_grid_desc_m_n.GetLength(I0); const auto N = d0_grid_desc_m_n.GetLength(I1); + constexpr auto lcm_A0K1_B0K1 = math::lcm(A0K1, B0K1); constexpr bool is_single_rate_mfma = - ((is_same::value || is_same::value) && - math::lcm(A0K1, B0K1) <= 4) + (((is_same::value || is_same::value) && + lcm_A0K1_B0K1 <= 4) || + (is_same::value && lcm_A0K1_B0K1 <= 8)) ? true : false; constexpr auto mfma = MfmaSelector::value || is_same::value) && - lcm_A0K1_B0K1 <= 4) + (((is_same::value || is_same::value) && + lcm_A0K1_B0K1 <= 4) || + (is_same::value && lcm_A0K1_B0K1 <= 8)) ? true : false; constexpr index_t KPack = diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp index 1f7458e68f..d15767f658 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp @@ -343,9 +343,11 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle const auto M = d0_grid_desc_m_n.GetLength(I0); const auto N = d0_grid_desc_m_n.GetLength(I1); + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); constexpr bool is_single_rate_mfma = - ((is_same::value || is_same::value) && - math::lcm(AK1, BK1) <= 4) + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8)) ? true : false; constexpr auto mfma = @@ -560,8 +562,9 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle // sanity check constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); constexpr bool is_single_rate_mfma = - ((is_same::value || is_same::value) && - lcm_AK1_BK1 <= 4) + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8)) ? true : false; constexpr index_t KPack = math::max( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp index f7746b470f..a11d696019 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp @@ -471,8 +471,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle // sanity check constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); constexpr bool is_single_rate_mfma = - ((is_same::value || is_same::value) && - lcm_AK1_BK1 <= 4) + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8)) ? true : false; constexpr index_t KPack = math::max( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp index 8b3f51b9b0..ab97a940a8 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp @@ -500,8 +500,9 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // sanity check constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); constexpr bool is_single_rate_mfma = - ((is_same::value || is_same::value) && - lcm_AK1_BK1 <= 4) + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8)) ? true : false; constexpr index_t KPack = math::max( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp index 344656b13f..79ab3acd92 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp @@ -674,10 +674,22 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register // sanity check - constexpr index_t KPack = math::max( - math::lcm(AK1, BK1), - MfmaSelector::selected_mfma - .k_per_blk); + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool is_single_rate_mfma = + (((is_same::value || + is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8)) + ? true + : false; + + constexpr index_t KPack = + math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp index 60ee78528d..0e51c6904c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -466,8 +466,9 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // sanity check constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); constexpr bool is_single_rate_mfma = - ((is_same::value || is_same::value) && - lcm_AK1_BK1 <= 4) + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8)) ? true : false; constexpr index_t KPack = math::max( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index 060f6d5d15..d54a00eaa2 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -635,10 +635,22 @@ struct GridwiseGemmMultipleD_xdl_cshuffle // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register // sanity check - constexpr index_t KPack = math::max( - math::lcm(AK1, BK1), - MfmaSelector::selected_mfma - .k_per_blk); + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool is_single_rate_mfma = + (((is_same::value || + is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8)) + ? true + : false; + + constexpr index_t KPack = + math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp index b4c5d004c4..57b9b02548 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp @@ -600,10 +600,22 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad // b_mtx[K0PerBlock, NPerBlock] is in LDS // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register - constexpr index_t KPack = math::max( - math::lcm(AK1, BK1), - MfmaSelector::selected_mfma - .k_per_blk); + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool is_single_rate_mfma = + (((is_same::value || + is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8)) + ? true + : false; + + constexpr index_t KPack = + math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp index d1d97da5b0..88d6be234c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp @@ -601,8 +601,9 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle // sanity check constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); constexpr bool is_single_rate_mfma = - ((is_same::value || is_same::value) && - lcm_AK1_BK1 <= 4) + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8)) ? true : false; constexpr index_t KPack = math::max( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp index 7105fa7012..56581256dc 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -453,8 +453,9 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // sanity check constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); constexpr bool is_single_rate_mfma = - ((is_same::value || is_same::value) && - lcm_AK1_BK1 <= 4) + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8)) ? true : false; constexpr index_t KPack = math::max( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp index 3429c20e73..23b4aec3b0 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp @@ -583,8 +583,9 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle // sanity check constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); constexpr bool is_single_rate_mfma = - ((is_same::value || is_same::value) && - lcm_AK1_BK1 <= 4) + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8)) ? true : false; constexpr index_t KPack = @@ -1015,8 +1016,9 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle // sanity check constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); constexpr bool is_single_rate_mfma = - ((is_same::value || is_same::value) && - lcm_AK1_BK1 <= 4) + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8)) ? true : false; constexpr index_t KPack = diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp index d7c87a170c..44c1e936bd 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp @@ -597,8 +597,9 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle // sanity check constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); constexpr bool is_single_rate_mfma = - ((is_same::value || is_same::value) && - lcm_AK1_BK1 <= 4) + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8)) ? true : false; constexpr index_t KPack = math::max( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp index 08d9386d72..4f5fedcd83 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp @@ -81,8 +81,9 @@ struct GridwiseGemm_xdl_cshuffle_v3 static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); static constexpr bool is_single_rate_mfma = - ((is_same::value || is_same::value) && - lcm_AK1_BK1 <= 4) + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8)) ? true : false; static constexpr index_t KPack = diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp index fcb12f4a14..cc8ae1806a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp @@ -141,8 +141,9 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); static constexpr bool is_single_rate_mfma = - ((is_same::value || is_same::value) && - lcm_AK1_BK1 <= 4) + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8)) ? true : false; static constexpr index_t KPack = diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp index c293d64ef0..240bc464e1 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -810,9 +810,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register // sanity check + constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8)) + ? true + : false; constexpr index_t KPack = math::max( - math::lcm(AK1Number, BK1Number), - MfmaSelector::selected_mfma.k_per_blk); + lcm_AK1_BK1, + MfmaSelector:: + selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp index af91721c8a..c7d44e842d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp @@ -871,8 +871,9 @@ struct GridwiseGemm_xdl_cshuffle_v2 // sanity check constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); constexpr bool is_single_rate_mfma = - ((is_same::value || is_same::value) && - lcm_AK1_BK1 <= 4) + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8)) ? true : false; constexpr index_t KPack = math::max( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index d4c915aa5e..55639f4aee 100755 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -149,8 +149,9 @@ struct GridwiseGemm_xdl_cshuffle_v3 static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); static constexpr bool is_single_rate_mfma = - ((is_same::value || is_same::value) && - lcm_AK1_BK1 <= 4) + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8)) ? true : false; static constexpr index_t KPack = diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp index 2e62110416..27818b6964 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp @@ -157,8 +157,9 @@ struct GridwiseGemm_xdl_cshuffle_v3 static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); static constexpr bool is_single_rate_mfma = - ((is_same::value || is_same::value) && - lcm_AK1_BK1 <= 4) + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8)) ? true : false; static constexpr index_t KPack = diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp index f9071bd29d..b805f600d5 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp @@ -193,9 +193,17 @@ struct GridwiseGemm_xdl_cshuffle_v3 using BsGridPointer = decltype(MakeBsGridPointer()); using DsGridPointer = decltype(MakeDsGridPointer()); - static constexpr index_t KPack = math::max( - math::lcm(AK1Number, BK1Number), - MfmaSelector::selected_mfma.k_per_blk); + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8)) + ? true + : false; + static constexpr index_t KPack = + math::max(lcm_AK1_BK1, + MfmaSelector:: + selected_mfma.k_per_blk); using ThisThreadBlock = ThisThreadBlock; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp index a9e73bf461..4163d1d01a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp @@ -179,9 +179,18 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 using DsGridPointer = decltype(MakeDsGridPointer()); - static constexpr index_t KPack = math::max( - math::lcm(AK1Number, BK1Number), - MfmaSelector::selected_mfma.k_per_blk); + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8)) + ? true + : false; + + static constexpr index_t KPack = + math::max(lcm_AK1_BK1, + MfmaSelector:: + selected_mfma.k_per_blk); using ThisThreadBlock = ThisThreadBlock; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp index 813acfa656..d10db3225e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp @@ -149,9 +149,18 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 using DsGridPointer = decltype(MakeDsGridPointer()); - static constexpr index_t KPack = math::max( - math::lcm(AK1Number, BK1Number), - MfmaSelector::selected_mfma.k_per_blk); + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8)) + ? true + : false; + + static constexpr index_t KPack = + math::max(lcm_AK1_BK1, + MfmaSelector:: + selected_mfma.k_per_blk); using ThisThreadBlock = ThisThreadBlock; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp index 0a62464cc2..b435fd5d5a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp @@ -491,8 +491,9 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // sanity check constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); constexpr bool is_single_rate_mfma = - ((is_same::value || is_same::value) && - lcm_AK1_BK1 <= 4) + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8)) ? true : false; constexpr index_t KPack = math::max( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp index 6a4b1cc14b..ad65e75ef9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp @@ -489,8 +489,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle // branch early for math wave constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); constexpr bool is_single_rate_mfma = - ((is_same::value || is_same::value) && - lcm_AK1_BK1 <= 4) + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8)) ? true : false; constexpr index_t KPack = math::max( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index b41e747a3a..168c553180 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -741,11 +741,17 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register // sanity check + constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + K1 <= 4) || + (is_same::value && K1 <= 8)) + ? true + : false; - constexpr index_t KPack = - math::max(K1, - MfmaSelector::selected_mfma - .k_per_blk); + constexpr index_t KPack = math::max( + K1, + MfmaSelector:: + selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1::value || is_same::value) && - lcm_AK1_BK1 <= 4) + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8)) ? true : false; constexpr index_t k_pack = math::max( diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 8c0b950941..9f6c0b5648 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -1053,40 +1053,49 @@ struct MfmaSelector #endif } + template <> + constexpr auto GetMfma() + { #if defined(__gfx950__) - template <> - constexpr auto GetMfma() - { return MfmaInstr::mfma_i32_32x32x32i8; - } - template <> - constexpr auto GetMfma() - { - return MfmaInstr::mfma_i32_16x16x64i8; - } #elif defined(__gfx942__) - template <> - constexpr auto GetMfma() - { return MfmaInstr::mfma_i32_32x32x16i8; - } - template <> - constexpr auto GetMfma() - { - return MfmaInstr::mfma_i32_16x16x32i8; - } #else - template <> - constexpr auto GetMfma() - { return MfmaInstr::mfma_i32_32x32x8i8; - } - template <> - constexpr auto GetMfma() - { - return MfmaInstr::mfma_i32_16x16x16i8; - } #endif + } + + template <> + constexpr auto GetMfma() + { +#if defined(__gfx942__) || defined(__gfx950__) + return MfmaInstr::mfma_i32_32x32x16i8; +#else + return MfmaInstr::mfma_i32_32x32x8i8; +#endif + } + + template <> + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_i32_16x16x64i8; +#elif defined(__gfx942__) + return MfmaInstr::mfma_i32_16x16x32i8; +#else + return MfmaInstr::mfma_i32_16x16x16i8; +#endif + } + + template <> + constexpr auto GetMfma() + { +#if defined(__gfx942__) || defined(__gfx950__) + return MfmaInstr::mfma_i32_16x16x32i8; +#else + return MfmaInstr::mfma_i32_16x16x16i8; +#endif + } template <> constexpr auto GetMfma() @@ -1440,12 +1449,13 @@ struct XdlopsGemm } // Falls back to single rate instruction on gfx950 if KPack <= 4; no change on gfx942- - static constexpr auto - mfma = MfmaSelector < base_type, - MPerXdlops, NPerXdlops, additional_type, - ((is_same::value || is_same::value) && KPack <= 4) - ? true - : false > {}; + static constexpr auto mfma = MfmaSelector < base_type, MPerXdlops, NPerXdlops, additional_type, + (((is_same::value || + is_same::value) && + KPack <= 4) || + (is_same::value && KPack <= 8)) + ? true + : false > {}; static constexpr auto mfma_instr = mfma.selected_mfma; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp index 3ebfd0bb7d..bea22da2c2 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp @@ -41,13 +41,11 @@ template using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_generic_instances = std::tuple< -// clang-format off + // clang-format off //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups| //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1> // clang-format on >; @@ -60,13 +58,11 @@ template using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_instances = std::tuple< -// clang-format off + // clang-format off //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups| //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>, @@ -110,13 +106,11 @@ template using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_generic_instances = std::tuple< -// clang-format off + // clang-format off //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups| //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1> // clang-format on >; @@ -129,13 +123,11 @@ template using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_instances = std::tuple< -// clang-format off + // clang-format off //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups| //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>, @@ -179,13 +171,11 @@ template using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_generic_instances = std::tuple< -// clang-format off + // clang-format off //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups| //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1, F16, F16, 1, 1> // clang-format on >; @@ -199,13 +189,11 @@ template using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances = std::tuple< -// clang-format off + // clang-format off //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups| //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1, F16, F16, 1, 1>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, F16, F16, 2, 2>, @@ -241,13 +229,11 @@ template using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_generic_instances = std::tuple< -// clang-format off + // clang-format off //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups| //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1, BF16, BF16, 1, 1> // clang-format on >; @@ -260,13 +246,11 @@ template using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_instances = std::tuple< -// clang-format off + // clang-format off //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups| //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1, BF16, BF16, 1, 1>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, BF16, BF16, 2, 2>, diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp index f9b3b43795..e7bbf8a26a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp @@ -49,6 +49,23 @@ static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; +// double rate mfma instances on gfx950 +template +using device_grouped_conv_fwd_xdl_bf16_comp_instances_2x = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + template using device_grouped_conv_fwd_xdl_bf16_comp_instances = std::tuple< -// clang-format off + // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> -#else // defined(CK_USE_AMD_MFMA_GFX950) // Compute friendly DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, @@ -71,18 +85,45 @@ using device_grouped_conv_fwd_xdl_bf16_comp_instances = std::tuple< DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - // AGPR Spill when use permuted lds layout. so, use padding for these two. - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3> -#endif // defined(CK_USE_AMD_MFMA_GFX950) + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_bf16_comp_instances_part2 = std::tuple< + // clang-format off + // AGPR Spill when use permuted lds layout. so, use padding for these two. + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5> + // clang-format on + >; + +// double rate mfma instances on gfx950 +template +using device_grouped_conv_fwd_xdl_f16_comp_instances_2x = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> // clang-format on >; @@ -93,16 +134,24 @@ template using device_grouped_conv_fwd_xdl_f16_comp_instances = std::tuple< -// clang-format off + // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> -#else // defined(CK_USE_AMD_MFMA_GFX950) + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4> + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_f16_comp_instances_part2 = std::tuple< + // clang-format off DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, @@ -115,7 +164,6 @@ using device_grouped_conv_fwd_xdl_f16_comp_instances = std::tuple< DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> -#endif // defined(CK_USE_AMD_MFMA_GFX950) // clang-format on >; @@ -138,6 +186,23 @@ using device_grouped_conv_fwd_xdl_f32_comp_instances = std::tuple< // clang-format on >; +// double rate mfma instances on gfx950 +template +using device_grouped_conv_fwd_xdl_int8_comp_instances_2x = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + template using device_grouped_conv_fwd_xdl_int8_comp_instances = std::tuple< -// clang-format off + // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> -#else // defined(CK_USE_AMD_MFMA_GFX950) + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4> + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_int8_comp_instances_part2 = std::tuple< + // clang-format off DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, @@ -163,10 +236,8 @@ using device_grouped_conv_fwd_xdl_int8_comp_instances = std::tuple< DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> -#endif // defined(CK_USE_AMD_MFMA_GFX950) // clang-format on >; - } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp index 9114d5c1fb..153cc61b09 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp @@ -40,21 +40,34 @@ template using device_grouped_conv_fwd_xdl_merged_groups_bf16_instances = std::tuple< -// clang-format off + // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ACompute| BCompute| BlockGemm| NumGroups| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Type| Type| Pipeline| ToMerge| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | Scheduler| | //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 16>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 32> -#else // defined(CK_USE_AMD_MFMA_GFX950) // Instances with NumGroupsPerBatch > 1 DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 8>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 16>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 32> -#endif // defined(CK_USE_AMD_MFMA_GFX950) + // clang-format on + >; + +// double rate mfma instances on gfx950 +template +using device_grouped_conv_fwd_xdl_merged_groups_bf16_instances_2x = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ACompute| BCompute| BlockGemm| NumGroups| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Type| Type| Pipeline| ToMerge| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | Scheduler| | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 16>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 32> // clang-format on >; @@ -65,22 +78,35 @@ template using device_grouped_conv_fwd_xdl_merged_groups_f16_instances = std::tuple< -// clang-format off + // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) - // Instances with NumGroupsPerBatch > 1 - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 16>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 32> -#else // defined(CK_USE_AMD_MFMA_GFX950) // Instances with NumGroupsPerBatch > 1 DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 16>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 32> -#endif // defined(CK_USE_AMD_MFMA_GFX950) + // clang-format on + >; + +// double rate mfma instances on gfx950 +template +using device_grouped_conv_fwd_xdl_merged_groups_f16_instances_2x = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // Instances with NumGroupsPerBatch > 1 + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 16>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 32> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp index dad67b396f..f3a32bab54 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp @@ -8,6 +8,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -27,15 +28,12 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; // Compilation parameters for a[k, m] * b[k, n] = c[m, n] using device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances = std::tuple< -// clang-format off + // clang-format off //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline| //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | | //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | | //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, @@ -70,6 +68,17 @@ using device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances = std::tuple< #endif // clang-format on >; +// double rate mfma instances on gfx950 +using device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances_2x = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | | + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | | + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances( std::vector, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, @@ -70,6 +68,17 @@ using device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances = std::tuple< #endif // clang-format on >; +// double rate mfma instances on gfx950 +using device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances_2x = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | | + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | | + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances( std::vector; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_generic_instances = std::tuple< -// clang-format off + // clang-format off //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline| //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | | //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | | //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 64, 16, 16, 4, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> // clang-format on >; +// double rate mfma instances on gfx950 +using device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_generic_instances_2x = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | | + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | | + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 64, 16, 16, 4, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; // Compilation parameters for a[m, k] * b[k, n] = c[m, n] using device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances = std::tuple< -// clang-format off + // clang-format off //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline| //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | | //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | | //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, @@ -110,6 +115,17 @@ using device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances = std::tuple< #endif // clang-format on >; +// double rate mfma instances on gfx950 +using device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances_2x = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | | + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | | + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances( std::vector; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_generic_instances = std::tuple< -// clang-format off + // clang-format off //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline| //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | | //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | | //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> // clang-format on >; +// double rate mfma instances on gfx950 +using device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_generic_instances_2x = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | | + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | | + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; // Compilation parameters for a[m, k] * b[n, k] = c[m, n] using device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances = std::tuple< -// clang-format off + // clang-format off //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline| //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | | //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | | //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, @@ -98,6 +103,17 @@ using device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances = std::tuple< #endif // clang-format on >; +// double rate mfma instances on gfx950 +using device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances_2x = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | | + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | | + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances( std::vector using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances = std::tuple< -// clang-format off + // clang-format off //#######################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskOut| //#######################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Upper| //#######################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Triangle| //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>, DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>, DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>, @@ -61,13 +59,11 @@ using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_ template using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_irregular_k_instances = std::tuple< -// clang-format off + // clang-format off //#######################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskOut| //#######################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Upper| //#######################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Triangle| //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 256, 128, 40, 64, 32, 4, 4, 2, 32, 32, 2, 4, 2, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>, DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 256, 128, 40, 128, 32, 4, 4, 2, 32, 32, 2, 4, 4, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>, DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 256, 40, 64, 32, 4, 4, 2, 32, 32, 1, 8, 2, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>, diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp index 8382f069d7..498bf58fb3 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp @@ -40,13 +40,11 @@ template using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances = std::tuple< -// clang-format off + // clang-format off // #############################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| AData| B0Data| B1Data| CData| Acc0BiasData| Acc1BiasData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskingSpec| D0s Bias| // #############################################| | | | | | Type| Type| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | SrcScalar| // #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | PerVector| // #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec, 1>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp index b6c14d69db..744bd6456d 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp @@ -40,13 +40,11 @@ template using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances = std::tuple< -// clang-format off + // clang-format off // #############################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| AData| B0Data| B1Data| CData| Acc0BiasData| Acc1BiasData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskingSpec| D0s Bias| // #############################################| | | | | | Type| Type| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | SrcScalar| // #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | PerVector| // #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec, 1>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp index 2ce5124706..b342612d1c 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp @@ -40,13 +40,11 @@ template using device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances = std::tuple< -// clang-format off + // clang-format off // #############################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| AData| B0Data| B1Data| CData| Acc0BiasData| Acc1BiasData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskingSpec| // #############################################| | | | | | Type| Type| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | // #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | // #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp index 4e8adceb1c..6d64a2e2d6 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp @@ -8,6 +8,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -58,15 +59,26 @@ using device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_ DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>, -#if !defined(CK_USE_AMD_MFMA_GFX950) - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>, -#endif // !defined(CK_USE_AMD_MFMA_GFX950) // Padded fallback kernel DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec> // clang-format on >; +// instances not working on gfx950 +template +using device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances_part2 = + std::tuple< + // clang-format off + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec> + // clang-format on + >; + void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( std::vector{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances_part2< + 2, + 1, + 1, + 1, + 1, + MaskingSpecialization::MaskOutUpperTriangle>{}); + } } void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( @@ -129,6 +154,19 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_g 1, 1, MaskingSpecialization::MaskDisabled>{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances_part2< + 2, + 1, + 1, + 1, + 1, + MaskingSpecialization::MaskDisabled>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp index c0c74ff7fb..86c17aacf0 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp @@ -8,6 +8,7 @@ #include "ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -40,14 +41,11 @@ static constexpr auto ConvFwdOddC = // arbitrary conv using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances = std::tuple< -// clang-format off + // clang-format off //##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, @@ -63,17 +61,24 @@ using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances = std::tuple< DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> // clang-format on >; - -// 1x1, pad 0 -using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std::tuple< -// clang-format off +// double rate mfma instances on gfx950 +using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_2x = std::tuple< + // clang-format off + //##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +// 1x1, pad 0 +using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std::tuple< + // clang-format off //##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, @@ -89,17 +94,24 @@ using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std: DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> // clang-format on >; - -// 1x1, stride 1, pad 0 -using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple< -// clang-format off +// double rate mfma instances on gfx950 +using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_1x1_p0_f16_instances_2x = std::tuple< + // clang-format off + //##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +// 1x1, stride 1, pad 0 +using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off //##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, @@ -115,16 +127,23 @@ using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = s DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> // clang-format on >; - -using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_odd_c_f16_instances = std::tuple< -// clang-format off +// double rate mfma instances on gfx950 +using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances_2x = std::tuple< + // clang-format off + //##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_odd_c_f16_instances = std::tuple< + // clang-format off //##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, @@ -145,6 +164,16 @@ using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_odd_c_f16_instances = std:: DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 64, 64, 2, 4, 32, 32, 1, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> // clang-format on >; +// double rate mfma instances on gfx950 +using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_odd_c_f16_instances_2x = std::tuple< + // clang-format off + //##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances( std::vector, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, @@ -62,17 +60,24 @@ using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances = s DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> // clang-format on >; - -// 1x1, pad 0 -using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std::tuple< -// clang-format off +// double rate mfma instances on gfx950 +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances_2x = std::tuple< + // clang-format off + //##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +// 1x1, pad 0 +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std::tuple< + // clang-format off //##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, @@ -88,17 +93,24 @@ using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_p0_f16_instan DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> // clang-format on >; - -// 1x1, stride 1, pad 0 -using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple< -// clang-format off +// double rate mfma instances on gfx950 +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_p0_f16_instances_2x = std::tuple< + // clang-format off + //##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +// 1x1, stride 1, pad 0 +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off //##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, @@ -114,17 +126,25 @@ using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_s1_p0_f16_ins DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> // clang-format on >; - -// Odd C -using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_odd_c_f16_instances = std::tuple< -// clang-format off +// double rate mfma instances on gfx950 +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances_2x = + std::tuple< + // clang-format off + //##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +// Odd C +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_odd_c_f16_instances = std::tuple< + // clang-format off //##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, @@ -145,6 +165,16 @@ using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_odd_c_f16_instanc DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 64, 64, 2, 4, 32, 32, 1, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> // clang-format on >; +// double rate mfma instances on gfx950 +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_odd_c_f16_instances_2x = std::tuple< + // clang-format off + //##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances( std::vector>& instances) @@ -158,6 +188,21 @@ void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances( device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{}); add_device_operation_instances( instances, device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_odd_c_f16_instances{}); + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances_2x{}); + add_device_operation_instances( + instances, + device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_p0_f16_instances_2x{}); + add_device_operation_instances( + instances, + device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances_2x{}); + add_device_operation_instances( + instances, + device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_odd_c_f16_instances_2x{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp index 5676d77986..faac2813ba 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp @@ -8,6 +8,7 @@ #include "ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -37,14 +38,11 @@ static constexpr auto ConvFwdOddC = // arbitrary conv using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances = std::tuple< -// clang-format off + // clang-format off //##############################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //##############################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //##############################################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, @@ -60,17 +58,24 @@ using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> // clang-format on >; - -// 1x1, pad 0 -using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std::tuple< -// clang-format off +// double rate mfma instances on gfx950 +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances_2x = std::tuple< + // clang-format off + //##############################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##############################################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +// 1x1, pad 0 +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std::tuple< + // clang-format off //##############################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //##############################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //##############################################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, @@ -86,17 +91,24 @@ using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_1x1_p0_f16_in DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> // clang-format on >; - -// 1x1, stride 1, pad 0 -using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple< -// clang-format off +// double rate mfma instances on gfx950 +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_1x1_p0_f16_instances_2x = std::tuple< + // clang-format off + //##############################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##############################################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +// 1x1, stride 1, pad 0 +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off //##############################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //##############################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //##############################################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, @@ -112,17 +124,25 @@ using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_1x1_s1_p0_f16 DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> // clang-format on >; - -// Odd C -using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_odd_c_f16_instances = std::tuple< -// clang-format off +// double rate mfma instances on gfx950 +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances_2x = + std::tuple< + // clang-format off + //##############################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##############################################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +// Odd C +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_odd_c_f16_instances = std::tuple< + // clang-format off //##############################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //##############################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //##############################################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, @@ -143,6 +163,16 @@ using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_odd_c_f16_ins DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 64, 64, 2, 4, 32, 32, 1, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> // clang-format on >; +// double rate mfma instances on gfx950 +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_odd_c_f16_instances_2x = std::tuple< + // clang-format off + //##############################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##############################################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances( std::vector>& instances) @@ -158,6 +188,22 @@ void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instan add_device_operation_instances( instances, device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_odd_c_f16_instances{}); + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances_2x{}); + add_device_operation_instances( + instances, + device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_1x1_p0_f16_instances_2x{}); + add_device_operation_instances( + instances, + device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances_2x{}); + add_device_operation_instances( + instances, + device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_odd_c_f16_instances_2x{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp index 0d143b95ee..7fcab3ea46 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp @@ -8,6 +8,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -29,15 +30,12 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa // Compilation parameters for a[m, k] * b[n, k] = c[m, n] using device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances = std::tuple< -// clang-format off + // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | Version| //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, @@ -87,6 +85,17 @@ using device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances = std::tu #endif // clang-format on >; +// double rate mfma instances on gfx950 +using device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances_2x = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | Version| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances( std::vector, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemm_Xdl_CShuffleV2< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 2, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> -#if defined(CK_USE_AMD_MFMA_GFX950) - , - DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 64, 128, 32, 32, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, PipelineVersion::v1> -#endif #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves , @@ -96,6 +92,18 @@ using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances = std::tuple< #endif // clang-format on >; +// double rate mfma instances on gfx950 +template +using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances_2x = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 64, 128, 32, 32, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances( std::vector{}); + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances_2x{}); + + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances_2x{}); + + add_device_operation_instances( + instances, + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances_2x{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp index 3ebd0c5351..27240122de 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp @@ -10,6 +10,7 @@ #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -31,15 +32,12 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa // Compilation parameters for a[k, m] * b[k, n] = c[m, n] using device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances = std::tuple< -// clang-format off + // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, @@ -98,6 +96,17 @@ using device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances = std::tuple< #endif // clang-format on >; +// double rate mfma instances on gfx950 +using device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances_2x = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( std::vector, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, @@ -98,6 +96,17 @@ using device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances = std::tuple< #endif // clang-format on >; +// double rate mfma instances on gfx950 +using device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances_2x = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( std::vector, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, 1, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, MNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, 1, LoopScheduler::Default, PipelineVersion::v1> // clang-format on >; +// double rate mfma instances on gfx950 +using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_generic_instances_2x = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, MNKPadding, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; // Compilation parameters for a[m, k] * b[k, n] = c[m, n] template using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances = std::tuple< -// clang-format off + // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) // pipeline v1, 1 wave DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, @@ -133,6 +138,12 @@ void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( add_device_operation_instances( instances, device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances{}); + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_generic_instances_2x{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp index 852b053527..278e56f556 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp @@ -11,6 +11,7 @@ #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -34,30 +35,35 @@ static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecializa // Compilation parameters for a[m, k] * b[n, k] = c[m, n] using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_generic_instances = std::tuple< -// clang-format off + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, MNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; +// double rate mfma instances on gfx950 +using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_generic_instances_2x = std::tuple< + // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) //DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, //DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 128, 32, 32, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, PipelineVersion::v1> - DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, MNKPadding, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, 1, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, MNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, 1, LoopScheduler::Default, PipelineVersion::v1> + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, MNKPadding, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, 1, LoopScheduler::Default, PipelineVersion::v1> // clang-format on >; template // Compilation parameters for a[m, k] * b[n, k] = c[m, n] using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances = std::tuple< -// clang-format off + // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) // pipeline v1, 1 wave DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, @@ -125,6 +131,12 @@ void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( add_device_operation_instances( instances, device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances{}); + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_generic_instances_2x{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp index 41f6ec2bf7..caf17d55cb 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp @@ -8,6 +8,8 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + #ifdef CK_ENABLE_INT8 namespace ck { namespace tensor_operation { @@ -47,12 +49,19 @@ using device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances = DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 32, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 16>, DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 64, 64, 16, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 16> -#if defined(CK_USE_AMD_MFMA_GFX950) - , + // clang-format on + >; +// double rate mfma instances on gfx950 +using device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances_2x = + std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 128, 32, 32, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 256, 64, 64, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, PipelineVersion::v1> -#endif // defined(CK_USE_AMD_MFMA_GFX950) - // clang-format on + // clang-format on >; void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances( @@ -62,6 +71,12 @@ void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances( { add_device_operation_instances(instances, device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances{}); + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances_2x{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp index 74cf837500..a64424e8ac 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp @@ -9,7 +9,8 @@ namespace device { namespace instance { // Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using Instances = std::tuple< +using Instances = + std::tuple< // clang-format off #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves @@ -17,8 +18,6 @@ using Instances = std::tuple< //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, @@ -28,8 +27,8 @@ using Instances = std::tuple< DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> #endif - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_interwave_pipeline_v1_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_interwave_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_interwave_pipeline_v1_instance.cpp index f2b28f3b40..0a0406baec 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_interwave_pipeline_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_interwave_pipeline_v1_instance.cpp @@ -17,8 +17,6 @@ using Instances = std::tuple< //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> #endif // clang-format on diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp index da5fefe5da..3671bea7a3 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp @@ -9,7 +9,8 @@ namespace device { namespace instance { // Compilation parameters for a[k, m] * b[n, k] = c[m, n] -using Instances = std::tuple< +using Instances = + std::tuple< // clang-format off #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves @@ -17,8 +18,6 @@ using Instances = std::tuple< //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, @@ -28,8 +27,8 @@ using Instances = std::tuple< DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> #endif - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_interwave_pipeline_v1_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_interwave_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_interwave_pipeline_v1_instance.cpp index b6c03b3367..95fc8ecb46 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_interwave_pipeline_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_interwave_pipeline_v1_instance.cpp @@ -17,8 +17,6 @@ using Instances = std::tuple< //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> #endif // clang-format on diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp index bf271cc3c3..fa53a3bf0f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp @@ -9,7 +9,8 @@ namespace device { namespace instance { // Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using Instances = std::tuple< +using Instances = + std::tuple< // clang-format off #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves @@ -17,8 +18,6 @@ using Instances = std::tuple< //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, @@ -37,8 +36,8 @@ using Instances = std::tuple< DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> #endif - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_interwave_pipeline_v1_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_interwave_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_interwave_pipeline_v1_instance.cpp index 0df59933c2..c9d1913aec 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_interwave_pipeline_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_interwave_pipeline_v1_instance.cpp @@ -17,8 +17,6 @@ using Instances = std::tuple< //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> #endif // clang-format on diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_interwave_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_interwave_pipeline_v1_instance.cpp index d9260d85ab..0410eabb70 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_interwave_pipeline_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_interwave_pipeline_v1_instance.cpp @@ -17,8 +17,6 @@ using Instances = std::tuple< //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_interwave_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_interwave_pipeline_v1_instance.cpp index 8b98133ada..a41919aab7 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_interwave_pipeline_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_interwave_pipeline_v1_instance.cpp @@ -17,8 +17,6 @@ using Instances = std::tuple< //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp index 9ea79b1467..33ee841354 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp @@ -10,6 +10,7 @@ #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -38,29 +39,36 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // input: a[k, m], b[k, n], d0[m, n], d1[m, n] using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_generic_instance = std::tuple< -// clang-format off + // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> // clang-format on >; -using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances = +// double rate mfma instances on gfx950 +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_generic_instance_2x = std::tuple< -// clang-format off + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances = + std::tuple< + // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, @@ -119,19 +127,28 @@ using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn #endif // clang-format on >; - -// irregular tile size -using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_irregular_tile_instances = +// double rate mfma instances on gfx950 +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances_2x = std::tuple< -// clang-format off + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + +// irregular tile size +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_irregular_tile_instances = + std::tuple< + // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves @@ -145,6 +162,18 @@ using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn #endif // clang-format on >; +// double rate mfma instances on gfx950 +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_irregular_tile_instances_2x = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances( std::vector, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> // clang-format on >; -using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances = +// double rate mfma instances on gfx950 +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_generic_instance_2x = std::tuple< -// clang-format off + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances = + std::tuple< + // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, @@ -119,19 +127,28 @@ using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn #endif // clang-format on >; - -// irregular tile size -using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_irregular_tile_instances = +// double rate mfma instances on gfx950 +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances_2x = std::tuple< -// clang-format off + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + +// irregular tile size +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_irregular_tile_instances = + std::tuple< + // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves @@ -145,6 +162,18 @@ using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn #endif // clang-format on >; +// double rate mfma instances on gfx950 +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_irregular_tile_instances_2x = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances( std::vector, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> // clang-format on >; -using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances = +// double rate mfma instances on gfx950 +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_generic_instance_2x = std::tuple< -// clang-format off + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 128, 128, 128, 64, 16, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances = + std::tuple< + // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, @@ -119,19 +127,28 @@ using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn #endif // clang-format on >; - -// irregular tile size -using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_irregular_tile_instances = +// double rate mfma instances on gfx950 +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances_2x = std::tuple< -// clang-format off + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + +// irregular tile size +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_irregular_tile_instances = + std::tuple< + // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves @@ -145,6 +162,18 @@ using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn #endif // clang-format on >; +// double rate mfma instances on gfx950 +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_irregular_tile_instances_2x = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances( std::vector, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> // clang-format on >; -using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances = +// double rate mfma instances on gfx950 +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_generic_instance_2x = std::tuple< -// clang-format off + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances = + std::tuple< + // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, @@ -110,19 +118,28 @@ using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn #endif // clang-format on >; - -// irregular tile size -using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_irregular_tile_instances = +// double rate mfma instances on gfx950 +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances_2x = std::tuple< -// clang-format off + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + +// irregular tile size +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_irregular_tile_instances = + std::tuple< + // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves @@ -136,6 +153,18 @@ using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn #endif // clang-format on >; +// double rate mfma instances on gfx950 +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_irregular_tile_instances_2x = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances( std::vector, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> // clang-format on >; +// double rate mfma instances on gfx950 +using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_generic_instance_2x = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances = std::tuple< -// clang-format off + // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, @@ -103,19 +111,27 @@ using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instanc #endif // clang-format on >; +// double rate mfma instances on gfx950 +using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances_2x = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; // irregular tile size using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_irregular_tile_instances = std::tuple< -// clang-format off + // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves @@ -129,6 +145,18 @@ using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_irregul #endif // clang-format on >; +// double rate mfma instances on gfx950 +using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_irregular_tile_instances_2x = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances( std::vector, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> // clang-format on >; +// double rate mfma instances on gfx950 +using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_generic_instance_2x = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances = std::tuple< -// clang-format off + // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, @@ -103,19 +111,27 @@ using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instanc #endif // clang-format on >; +// double rate mfma instances on gfx950 +using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances_2x = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; // irregular tile size using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_irregular_tile_instances = std::tuple< -// clang-format off + // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves @@ -129,6 +145,18 @@ using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_irregul #endif // clang-format on >; +// double rate mfma instances on gfx950 +using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_irregular_tile_instances_2x = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances( std::vector, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> // clang-format on >; +// double rate mfma instances on gfx950 +using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_generic_instance_2x = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std::tuple< -// clang-format off + // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, @@ -103,19 +111,27 @@ using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instanc #endif // clang-format on >; +// double rate mfma instances on gfx950 +using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances_2x = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; // irregular tile size using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_irregular_tile_instances = std::tuple< -// clang-format off + // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves @@ -129,6 +145,18 @@ using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_irregul #endif // clang-format on >; +// double rate mfma instances on gfx950 +using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_irregular_tile_instances_2x = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( std::vector, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> // clang-format on >; +// double rate mfma instances on gfx950 +using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_generic_instance_2x = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances = std::tuple< -// clang-format off + // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, @@ -94,19 +102,27 @@ using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instanc #endif // clang-format on >; +// double rate mfma instances on gfx950 +using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances_2x = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; // irregular tile size using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_irregular_tile_instances = std::tuple< -// clang-format off + // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves @@ -120,6 +136,18 @@ using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_irregul #endif // clang-format on >; +// double rate mfma instances on gfx950 +using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_irregular_tile_instances_2x = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances( std::vector using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances = std::tuple< -// clang-format off + // clang-format off //#######################################| A| B| Ds| H| AData| BData| AccData| CShuffle| DsData| EMeanVarData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| LoopScheduler| Pipeline| //#######################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize| | | //#######################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _NWaveNPerXdl| _M_N| _M| | | //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, @@ -66,13 +64,11 @@ using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instan // irregular tile size using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_irregular_tile_instances = std::tuple< -// clang-format off + // clang-format off //#######################################| A| B| Ds| H| AData| BData| AccData| CShuffle| DsData| EMeanVarData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| LoopScheduler| Pipeline| //#######################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize| | | //#######################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _NWaveNPerXdl| _M_N| _M| | | //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) // pipeline v1, 1 wave DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<16, 4>, 1, S<16, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp index ef37c82c7f..13366238d6 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp @@ -37,13 +37,11 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // input: a[k, m], b[k, n], d0[m, n], d1[m, n], gamma[n], beta[n] template using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances = std::tuple< -// clang-format off + // clang-format off //#######################################| A| B| Ds| H| AData| BData| AccData| CShuffle| DsData| EMeanVarData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| LoopScheduler| Pipeline| //#######################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize| | | //#######################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _NWaveNPerXdl| _M_N| _M| | | //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, @@ -66,13 +64,11 @@ using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instan // irregular tile size using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_irregular_tile_instances = std::tuple< -// clang-format off + // clang-format off //#######################################| A| B| Ds| H| AData| BData| AccData| CShuffle| DsData| EMeanVarData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| LoopScheduler| Pipeline| //#######################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize| | | //#######################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _NWaveNPerXdl| _M_N| _M| | | //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) // pipeline v1, 1 wave DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<16, 4>, 1, S<16, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp index 40fbc85be0..8a4889ee83 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp @@ -37,13 +37,11 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // input: a[k, m], b[k, n], d0[m, n], d1[m, n], gamma[n], beta[n] template using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances = std::tuple< -// clang-format off + // clang-format off //#######################################| A| B| Ds| H| AData| BData| AccData| CShuffle| DsData| EMeanVarData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| LoopScheduler| Pipeline| //#######################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize| | | //#######################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _NWaveNPerXdl| _M_N| _M| | | //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, @@ -66,13 +64,11 @@ using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instan // irregular tile size using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_irregular_tile_instances = std::tuple< -// clang-format off + // clang-format off //#######################################| A| B| Ds| H| AData| BData| AccData| CShuffle| DsData| EMeanVarData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| LoopScheduler| Pipeline| //#######################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize| | | //#######################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _NWaveNPerXdl| _M_N| _M| | | //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) // pipeline v1, 1 wave DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<16, 4>, 1, S<16, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp index 464279c376..fc3cbcf905 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp @@ -37,13 +37,11 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // input: a[k, m], b[k, n], d0[m, n], d1[m, n], gamma[n], beta[n] template using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances = std::tuple< -// clang-format off + // clang-format off //#######################################| A| B| Ds| H| AData| BData| AccData| CShuffle| DsData| EMeanVarData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| LoopScheduler| Pipeline| //#######################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize| | | //#######################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _NWaveNPerXdl| _M_N| _M| | | //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<16, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>, @@ -63,13 +61,11 @@ using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instan // irregular tile size using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_irregular_tile_instances = std::tuple< -// clang-format off + // clang-format off //#######################################| A| B| Ds| H| AData| BData| AccData| CShuffle| DsData| EMeanVarData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| LoopScheduler| Pipeline| //#######################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize| | | //#######################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _NWaveNPerXdl| _M_N| _M| | | //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) // pipeline v1, 1 wave DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<16, 4>, 1, S<16, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES diff --git a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp index e2bf62ca94..c5e82abed6 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp @@ -6,6 +6,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" #include "ck/utility/sequence.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -22,28 +23,34 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // outout: e[m, n] // input: a[k, m], b[k, n] using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_generic_instance = std::tuple< -// clang-format off + // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> // clang-format on >; -using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances = std::tuple< -// clang-format off +// double rate mfma instances on gfx950 +using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_generic_instance_2x = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + +using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances = std::tuple< + // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, @@ -102,18 +109,26 @@ using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances = std::t #endif // clang-format on >; - -// irregular tile size -using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_irregular_tile_instances = std::tuple< -// clang-format off +// double rate mfma instances on gfx950 +using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances_2x = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + +// irregular tile size +using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_irregular_tile_instances = std::tuple< + // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves @@ -127,6 +142,18 @@ using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_irregular_tile_ins #endif // clang-format on >; +// double rate mfma instances on gfx950 +using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_irregular_tile_instances_2x = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( std::vector, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> // clang-format on >; -using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances = std::tuple< -// clang-format off +// double rate mfma instances on gfx950 +using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_generic_instance_2x = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + +using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, @@ -102,18 +109,26 @@ using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances = std::t #endif // clang-format on >; - -// irregular tile size -using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_irregular_tile_instances = std::tuple< -// clang-format off +// double rate mfma instances on gfx950 +using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances_2x = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + +// irregular tile size +using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_irregular_tile_instances = std::tuple< + // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves @@ -127,6 +142,18 @@ using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_irregular_tile_ins #endif // clang-format on >; +// double rate mfma instances on gfx950 +using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_irregular_tile_instances_2x = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( std::vector, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> // clang-format on >; +// double rate mfma instances on gfx950 +using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_generic_instance_2x = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances = std::tuple< -// clang-format off + // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, @@ -102,18 +109,26 @@ using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances = std::t #endif // clang-format on >; +// double rate mfma instances on gfx950 +using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances_2x = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; // irregular tile size using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_irregular_tile_instances = std::tuple< -// clang-format off + // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves @@ -127,6 +142,18 @@ using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_irregular_tile_ins #endif // clang-format on >; +// double rate mfma instances on gfx950 +using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_irregular_tile_instances_2x = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( std::vector, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> // clang-format on >; +// double rate mfma instances on gfx950 +using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_generic_instance_2x = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances = std::tuple< -// clang-format off + // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, @@ -93,18 +100,26 @@ using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances = std::t #endif // clang-format on >; +// double rate mfma instances on gfx950 +using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances_2x = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; // irregular tile size using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_irregular_tile_instances = std::tuple< -// clang-format off + // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave -#if defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves @@ -118,6 +133,18 @@ using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_irregular_tile_ins #endif // clang-format on >; +// double rate mfma instances on gfx950 +using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_irregular_tile_instances_2x = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( std::vector using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_iw_instances = std::tuple< -// clang-format off + // clang-format off //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) //PipelineVersion::v1; interwave DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_irregular_instance.cpp index 0b48bbf606..6a323d323f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_irregular_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_irregular_instance.cpp @@ -34,13 +34,11 @@ template using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances = std::tuple< -// clang-format off + // clang-format off //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_instance.cpp index 422db05b35..2855235f97 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_instance.cpp @@ -33,14 +33,12 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // Compilation parameters for a[m, k] * b[k, n] = c[m, n] template using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_iw_instances = std::tuple< -// clang-format off + // clang-format off //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //PipelineVersion::v1; interwave -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp index 5540d2d884..59154f3439 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp @@ -36,13 +36,11 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template using device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_instances = std::tuple< -// clang-format off + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, @@ -60,13 +58,11 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_instances = std::tu template using device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_instances = std::tuple< -// clang-format off + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) // Latency friendly DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp index d9f9969621..6f0d722b60 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp @@ -36,17 +36,12 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template using device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = std::tuple< -// clang-format off + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) // Compute friendly -#if !defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 8, 32, 32, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, -#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 32, 32, 4, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, @@ -56,9 +51,6 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = std::tu DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 8, 32, 32, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 32, 32, 4, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, -#if !defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, -#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, @@ -66,16 +58,22 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = std::tu DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 2, 2, 32, 32, 2, 2, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> // clang-format on >; +// instances not working on gfx950 +template +using device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances_part2 = std::tuple< + // clang-format off + DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 8, 32, 32, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; template using device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_instances = std::tuple< -// clang-format off + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) // Latency friendly DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 8, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, @@ -90,15 +88,19 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_instances = std::tup DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 8, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 8, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 4, 8, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, -#if !defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 4, 8, 16, 16, 1, 2, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, -#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 4, 4, 16, 16, 1, 2, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 4, 8, 16, 16, 1, 4, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 8, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 2, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> // clang-format on >; +// instances not working on gfx950 +template +using device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_instances_part2 = std::tuple< + // clang-format off + DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 4, 8, 16, 16, 1, 2, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instance.cpp index 02272e84b3..5396c16cb8 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -15,6 +16,13 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instance { add_device_operation_instances( instances, device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instance.cpp index 89acbb6f68..5cc2d987bc 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -16,6 +17,13 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instanc add_device_operation_instances( instances, device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instance.cpp index 52227620b4..3f7e2fe302 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -16,6 +17,14 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instan add_device_operation_instances( instances, device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances_part2< + GemmMKPadding>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mpadding_instance.cpp index 311b1d0b28..550fbce684 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -16,6 +17,13 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mpadding_instanc add_device_operation_instances( instances, device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_default_instance.cpp index 40a9239ad8..01fcabc872 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_default_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -16,6 +17,14 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_default_instan add_device_operation_instances( instances, device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_instance.cpp index 3a9dc7b081..22ad32465e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -16,6 +17,14 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_insta add_device_operation_instances( instances, device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_instance.cpp index 8fef199b25..335bbb6d36 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -17,6 +18,14 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_inst instances, device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_default_instance.cpp index 4f2c95e20d..3102f7c158 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_default_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -16,6 +17,14 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_default_instan add_device_operation_instances( instances, device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_instance.cpp index c58cfb592f..7c0ce3019f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -16,6 +17,14 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_insta add_device_operation_instances( instances, device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instance.cpp index 7cd3bfc882..e54c611e73 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -17,6 +18,14 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_inst instances, device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp index 0619a98cf0..50fdca9348 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp @@ -36,39 +36,37 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template using device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = std::tuple< -// clang-format off + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) -#if !defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, -#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 32, 32, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, -#if !defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 2, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, -#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> // clang-format on >; +// instances not working on gfx950 +template +using device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances_part2 = std::tuple< + // clang-format off + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 2, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; template using device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_instances = std::tuple< -// clang-format off + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) // Latency friendly DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 4, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 8, 4, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, @@ -83,13 +81,18 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_instances = std::tup DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, -#if !defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 4, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, -#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> // clang-format on >; +// instances not working on gfx950 +template +using device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_instances_part2 = std::tuple< + // clang-format off + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 4, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // clang-format on + >; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp index 246e7a5067..e86d0a6ff9 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -15,6 +16,13 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance { add_device_operation_instances( instances, device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instance.cpp index 9737dd5f0c..5f619d2850 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -16,6 +17,13 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instanc add_device_operation_instances( instances, device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp index 01f1315646..6a1ad1e890 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -16,6 +17,14 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_insta add_device_operation_instances( instances, device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances_part2< + GemmMNKPadding>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instance.cpp index 0d1cb4f25b..65c1cbce96 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -16,6 +17,14 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instan add_device_operation_instances( instances, device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances_part2< + GemmMNPadding>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instance.cpp index 377e2f90a6..d98dd7b99d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -16,6 +17,14 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instan add_device_operation_instances( instances, device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instance.cpp index 2ad1d1e52f..9da5e8b1cd 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -16,6 +17,14 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_insta add_device_operation_instances( instances, device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp index f82fb92302..fea562d208 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -17,6 +18,15 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_ins instances, device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_instances_part2< + Intrawave, + GemmMNKPadding>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instance.cpp index 2f8abf0a88..cc05571615 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -16,6 +17,14 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instan add_device_operation_instances( instances, device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instance.cpp index 2e0f670aad..939adc4c90 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -16,6 +17,14 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_insta add_device_operation_instances( instances, device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp index d7dc599748..0f3d081cf7 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -17,6 +18,15 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_ins instances, device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_instances_part2< + Interwave, + GemmMNKPadding>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp index 3edbd28cd8..7d141a47e1 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp @@ -36,44 +36,42 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template using device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances = std::tuple< -// clang-format off + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) // Compute friendly -#if !defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, -#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 32, 32, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, -#if !defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, - // AGPR Spill - // DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, - // AGPR Spill when use permuted lds layout. so, use padding for these two. - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 192, 64, 8, 8, 32, 32, 4, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, -#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> // clang-format on >; +// instances not working on gfx950 +template +using device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances_part2 = std::tuple< + // clang-format off + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + // AGPR Spill + // DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + // AGPR Spill when use permuted lds layout. so, use padding for these two. + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 192, 64, 8, 8, 32, 32, 4, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; template using device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_instances = std::tuple< -// clang-format off + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) // Latency friendly DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, @@ -90,13 +88,17 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_instances = std::tup DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, -#if !defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, -#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> // clang-format on >; +// instances not working on gfx950 +template +using device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_instances_part2 = std::tuple< + // clang-format off + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp index cbfaf9aaa0..d30ce71d6a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -15,6 +16,13 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance { add_device_operation_instances( instances, device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instance.cpp index 3b0ccf9d83..b908cacb4b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -16,6 +17,13 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instanc add_device_operation_instances( instances, device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instance.cpp index 3800b7955d..301395c4e6 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -16,6 +17,14 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instan add_device_operation_instances( instances, device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp index 9e78df7b5e..ddaa504345 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -16,6 +17,14 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_insta add_device_operation_instances( instances, device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instance.cpp index 21f118ea60..b08137c17f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -16,6 +17,14 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instan add_device_operation_instances( instances, device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp index 3e841a8498..ed553738b1 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -16,6 +17,14 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_insta add_device_operation_instances( instances, device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp index e3f3afff46..940da94e70 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp @@ -34,22 +34,17 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template using device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances = std::tuple< -// clang-format off + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 32, 32, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, -#if !defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, -#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, @@ -58,16 +53,21 @@ using device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances = std::tuple DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 2, 2, 32, 32, 2, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> // clang-format on >; +// instances not working on gfx950 +template +using device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances_part2 = std::tuple< + // clang-format off + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; template using device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_instances = std::tuple< -// clang-format off + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) // Latency friendly DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 4, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 8, 4, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp index 41d6481c9f..8e5ce4ded4 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -15,6 +16,13 @@ void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instances( { add_device_operation_instances( instances, device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp index de41821d98..1922e2b9ab 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -15,6 +16,13 @@ void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instances( { add_device_operation_instances( instances, device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp index cdde9fa43c..8cfcb85aa0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -15,6 +16,13 @@ void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance { add_device_operation_instances( instances, device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp index 04237cc62c..e502314388 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -15,6 +16,13 @@ void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances { add_device_operation_instances( instances, device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp index e39c9a63b9..d83014d5e8 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp @@ -34,13 +34,11 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances = std::tuple< -// clang-format off + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) // Compute friendly DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, @@ -57,13 +55,9 @@ using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances = std::tuple DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 16, 16, 8, 8, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, -#if !defined(CK_USE_AMD_MFMA_GFX950) // AGPR Spill // DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, // AGPR Spill when use permuted lds layout. so, use padding for these two. - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, -#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, @@ -71,16 +65,22 @@ using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances = std::tuple DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> // clang-format on >; +// instances not working on gfx950 +template +using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances_part2 = std::tuple< + // clang-format off + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; template using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances = std::tuple< -// clang-format off + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) // Latency friendly DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, @@ -104,10 +104,6 @@ using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances = std::tuple< DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, -#if !defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, -#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 64, 8, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, @@ -116,6 +112,14 @@ using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances = std::tuple< DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 2, 2, 32, 32, 1, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> // clang-format on >; +// instances not working on gfx950 +template +using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances_part2 = std::tuple< + // clang-format off + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp index 77addd6ad2..ecae32aa18 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -15,6 +16,13 @@ void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_default_instances( { add_device_operation_instances( instances, device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp index 4fb034d3b0..f65f47a08a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -15,6 +16,13 @@ void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instances( { add_device_operation_instances( instances, device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp index 7f7ec14ba6..3d16b5c282 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -16,6 +17,14 @@ void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instances add_device_operation_instances( instances, device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp index 32634a6129..b834022c91 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -16,6 +17,14 @@ void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance add_device_operation_instances( instances, device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp index 3062add942..50063c81ad 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -16,6 +17,14 @@ void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instances add_device_operation_instances( instances, device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp index ede5e4c428..99469f5436 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -16,6 +17,14 @@ void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance add_device_operation_instances( instances, device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp index 43dc6be076..ff13de1d6a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp @@ -35,36 +35,36 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template using device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances = std::tuple< -// clang-format off + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, // Disable due to test failure // DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, -#if !defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 4, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, -#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> // clang-format on >; +// instances not working on gfx950 +template +using device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances_part2 = std::tuple< + // clang-format off + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 4, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; template using device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_instances = std::tuple< -// clang-format off + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) // Latency friendly DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instance.cpp index f3e96e83f8..0a86fd3f9e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -15,6 +16,13 @@ void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instances( { add_device_operation_instances( instances, device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instance.cpp index d73b75fcdc..e88928efbc 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -15,6 +16,13 @@ void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instances( { add_device_operation_instances( instances, device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instance.cpp index 19894a4402..05b3e3762d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -15,6 +16,13 @@ void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instances { add_device_operation_instances( instances, device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instance.cpp index f1123e5715..83db356dc8 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -15,6 +16,13 @@ void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instances( { add_device_operation_instances( instances, device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp index 9bdb2f51c2..bb10da37f4 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp @@ -35,33 +35,33 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template using device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_instances = std::tuple< -// clang-format off + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) // Compute friendly DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, -#if !defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, -#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> // clang-format on >; +// instances not working on gfx950 +template +using device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_instances_part2 = std::tuple< + // clang-format off + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; template using device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_instances = std::tuple< -// clang-format off + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) // Latency friendly DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instance.cpp index b6d916f26c..bae0d6f7c6 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -15,6 +16,13 @@ void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instances( { add_device_operation_instances( instances, device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instance.cpp index e72a748e96..906368f54d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -15,6 +16,13 @@ void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instances( { add_device_operation_instances( instances, device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn.hpp index 616133e1ba..315a6446c1 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn.hpp @@ -47,11 +47,6 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_instances = std::tuple< DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 128, 16, 8, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, -#if !defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 16, 4, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 128, 16, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, -#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 16, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F8>, @@ -61,6 +56,23 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_instances = std::tuple< #endif // clang-format on >; +// instances not working on gfx950 +template +using device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_instances_part2 = std::tuple< +// clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) + //Only enable these instances on gfx94x + // Compute friendly + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 16, 4, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 128, 16, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> +#endif + // clang-format on + >; template using device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_instances = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp index 96f171e066..e2f843ac58 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -15,6 +16,13 @@ void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_default_instances( { add_device_operation_instances( instances, device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_kpadding_instance.cpp index 4965fe51c6..3778848db3 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_kpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -15,6 +16,13 @@ void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_kpadding_instances( { add_device_operation_instances( instances, device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instance.cpp index d325c47d8a..2e6224c139 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -15,6 +16,13 @@ void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instances( { add_device_operation_instances( instances, device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp index 6388c13444..27d7933477 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp @@ -42,16 +42,7 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances = std::tuple< //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | #if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) // Compute friendly - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 16, 16, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, -#if !defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, -#endif // !defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 192, 128, 16, 16, 32, 32, 4, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F8>, @@ -63,6 +54,26 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances = std::tuple< #endif // clang-format on >; +// instances not working on gfx950 +template +using device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances_part2 = std::tuple< +// clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) + // Compute friendly + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 16, 16, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 192, 128, 16, 16, 32, 32, 4, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> +#endif + // clang-format on + >; template using device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp index 48581a7344..d6c9809020 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -15,6 +16,13 @@ void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances( { add_device_operation_instances( instances, device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp index 09bc544deb..fc6ad01742 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -15,6 +16,13 @@ void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances( { add_device_operation_instances( instances, device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances_part2{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp index 7fb690c8b2..1cc0b8c76e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp @@ -50,9 +50,6 @@ using device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances = DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 32, 32, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, -#if !defined(CK_USE_AMD_MFMA_GFX950) - DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, -#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 160, 64, 8, 8, 16, 16, 8, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 32, 32, 1, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, @@ -62,6 +59,15 @@ using device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances = DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> // clang-format on >; +// instances not working on gfx950 +template , + typename DsDataType = ck::Tuple<>> +using device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances_part2 = std::tuple< + // clang-format off + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; template {}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances_part2< + GemmDefault>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp index 68c6ce89ab..6cadb7deb8 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp @@ -39,20 +39,15 @@ template , typename DsDataType = ck::Tuple<>> using device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_instances = std::tuple< -// clang-format off + // clang-format off //#########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, -#if !defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, -#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, @@ -60,19 +55,26 @@ using device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_instances = DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> // clang-format on >; +// instances not working on gfx950 +template , + typename DsDataType = ck::Tuple<>> +using device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_instances_part2 = std::tuple< + // clang-format off + DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; template , typename DsDataType = ck::Tuple<>> using device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_instances = std::tuple< -// clang-format off + // clang-format off //#########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) // Latency friendly DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 4, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp index bd5cfc5f20..4be2db2531 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -24,6 +25,14 @@ void add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_default_i add_device_operation_instances( instances, device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_instances_part2< + GemmDefault>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instance.cpp index f5c2c95a03..93fa56c1a5 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -24,6 +25,14 @@ void add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_ add_device_operation_instances( instances, device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_instances_part2< + GemmKPadding>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp index 978a027048..cd032e2a23 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -24,6 +25,14 @@ void add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_mnkpaddin add_device_operation_instances( instances, device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_instances_part2< + GemmMNKPadding>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instance.cpp index c8296e9bfc..3ff8d8e172 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -24,6 +25,14 @@ void add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding add_device_operation_instances( instances, device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_instances_part2< + GemmMNPadding>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_i8_bf16/device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_i8_bf16/device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn.hpp index 4e1e5567d5..c753a91d0a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_i8_bf16/device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_i8_bf16/device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn.hpp @@ -40,15 +40,11 @@ template , typename DsDataType = ck::Tuple<>> using device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_instances = std::tuple< -// clang-format off + // clang-format off //#########################| ALayout| BLayout| DsLayout| CLayout|AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) - //DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 128, 16, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v2>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, @@ -59,20 +55,30 @@ using device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_instances = st DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> // clang-format on >; +// double rate mfma instances on gfx950 +template , + typename DsDataType = ck::Tuple<>> +using device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_instances_2x = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| DsLayout| CLayout|AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 128, 16, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v2> + // clang-format on + >; template , typename DsDataType = ck::Tuple<>> using device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_instances = std::tuple< -// clang-format off + // clang-format off //#########################| ALayout| BLayout| DsLayout| CLayout|AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) - //DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 128, 16, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> -#else // Latency friendly DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, @@ -85,7 +91,6 @@ using device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_instances = std DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> -#endif // defined(CK_USE_AMD_MFMA_GFX950) // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_i8_bf16/device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_i8_bf16/device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp index 66db3ddb3d..0ad6c2c75a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_i8_bf16/device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_i8_bf16/device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -24,6 +25,14 @@ void add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_default_ins add_device_operation_instances( instances, device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_instances{}); + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_instances_2x< + GemmDefault>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_i8_bf16/device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_i8_bf16/device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instance.cpp index 7d3c832ee5..b5f784b9aa 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_i8_bf16/device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_i8_bf16/device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -24,6 +25,14 @@ void add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_kpadding_in add_device_operation_instances( instances, device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_instances{}); + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_instances_2x< + GemmKPadding>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_i8_bf16/device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_i8_bf16/device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp index 2759b878a3..1c48fea463 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_i8_bf16/device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_i8_bf16/device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -24,6 +25,14 @@ void add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_ add_device_operation_instances( instances, device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_instances{}); + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_instances_2x< + GemmMNKPadding>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_i8_bf16/device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_i8_bf16/device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instance.cpp index 04a0229300..e500ebde87 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_i8_bf16/device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_i8_bf16/device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -24,6 +25,14 @@ void add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_i add_device_operation_instances( instances, device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_instances{}); + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_instances_2x< + GemmMNPadding>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp index 928c325ab7..d4fed3b561 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp @@ -39,20 +39,15 @@ template , typename DsDataType = ck::Tuple<>> using device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_instances = std::tuple< -// clang-format off + // clang-format off //#########################| ALayout| BLayout| DsLayout| CLayout|AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, -#if !defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, -#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, @@ -60,19 +55,26 @@ using device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_instances = std DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> // clang-format on >; +// instances not working on gfx950 +template , + typename DsDataType = ck::Tuple<>> +using device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_instances_part2 = std::tuple< + // clang-format off + DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; template , typename DsDataType = ck::Tuple<>> using device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_instances = std::tuple< -// clang-format off + // clang-format off //#########################| ALayout| BLayout| DsLayout| CLayout|AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) // Latency friendly DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 4, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp index 53d2f8f0c3..00d88467e8 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -24,6 +25,14 @@ void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_default_inst add_device_operation_instances( instances, device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_instances_part2< + GemmDefault>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp index 11f696c509..a7e60a6d40 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -24,6 +25,14 @@ void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_kpadding_ins add_device_operation_instances( instances, device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_instances_part2< + GemmKPadding>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp index 34e129bdcb..7ce358c695 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -24,6 +25,14 @@ void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_mnkpadding_i add_device_operation_instances( instances, device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_instances_part2< + GemmMNKPadding>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp index 7a580b6716..e64c145987 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -24,6 +25,14 @@ void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_mnpadding_in add_device_operation_instances( instances, device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_instances_part2< + GemmMNPadding>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn.hpp index 2a02995827..5353fe16b5 100755 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn.hpp @@ -48,14 +48,18 @@ using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_instances = DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, -#if !defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 2, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, -#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> // clang-format on >; +// instances not working on gfx950 +template +using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_instances_part2 = std::tuple< + // clang-format off + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 2, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; template using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_instances = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp old mode 100755 new mode 100644 index 8b2bfb5d26..9d95e6ea5a --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -22,6 +23,14 @@ void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_default_ add_device_operation_instances( instances, device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_instances_part2< + GemmDefault>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instance.cpp old mode 100755 new mode 100644 index a7c33ffdc4..3e54a4bb5b --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -22,6 +23,14 @@ void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_kpadding add_device_operation_instances( instances, device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_instances_part2< + GemmKPadding>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp old mode 100755 new mode 100644 index adc2f23d40..544723ef6a --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -22,6 +23,14 @@ void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_mnkpaddi add_device_operation_instances( instances, device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_instances_part2< + GemmMNKPadding>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instance.cpp old mode 100755 new mode 100644 index 0336f64665..e6959c0945 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -22,6 +23,14 @@ void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_mnpaddin add_device_operation_instances( instances, device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_instances_part2< + GemmMNPadding>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp index 425c2c0391..209d8f644e 100755 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp @@ -51,11 +51,7 @@ using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_instances = // AGPR Spill // DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, // AGPR Spill when use permuted lds layout. so, use padding for these two. -#if !defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - -#endif // !defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, @@ -65,6 +61,13 @@ using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_instances = // clang-format on >; +// instances not working on gfx950 +template +using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_instances_part2 = std::tuple< + // clang-format off + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; template using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_instances = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp old mode 100755 new mode 100644 index e192bf14c5..924ed814ce --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -22,6 +23,14 @@ void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_default_ add_device_operation_instances( instances, device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_instances_part2< + GemmDefault>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instance.cpp old mode 100755 new mode 100644 index d58ec3eb33..daed25130f --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -22,6 +23,14 @@ void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_kpadding add_device_operation_instances( instances, device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_instances_part2< + GemmKPadding>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp old mode 100755 new mode 100644 index aa193417d9..282cea7563 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp @@ -34,14 +34,11 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances = std::tuple< -// clang-format off + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(__gfx950__) - DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 128, 16, 2, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v2> -#else DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 32, 32, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, @@ -57,19 +54,16 @@ using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances = st DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 2, 2, 32, 32, 2, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> -#endif // defined(__gfx950__) // clang-format on >; template using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances = std::tuple< -// clang-format off + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(__gfx950__) -#else // Latency friendly DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, @@ -100,7 +94,6 @@ using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances = std DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 4, 4, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> -#endif // defined(__gfx950__) // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp old mode 100755 new mode 100644 index a685c4f252..7335a9851f --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp @@ -34,13 +34,11 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances = std::tuple< -// clang-format off + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) // Compute friendly DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, @@ -58,9 +56,6 @@ using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances = st DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 16, 16, 8, 8, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, -#if !defined(CK_USE_AMD_MFMA_GFX950) - DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, -#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, @@ -69,16 +64,21 @@ using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances = st DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> // clang-format on >; +// instances not working on gfx950 +template +using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances_part2 = std::tuple< + // clang-format off + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; template using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances = std::tuple< -// clang-format off + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) // Latency friendly DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp index 546f909b3c..4f048b6525 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -22,6 +23,14 @@ void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_ins add_device_operation_instances( instances, device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances_part2< + GemmDefault>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp index d91de96be3..c547487983 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -22,6 +23,14 @@ void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_in add_device_operation_instances( instances, device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances_part2< + GemmKPadding>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp index c70678b449..28afafa26d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -22,6 +23,14 @@ void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_ add_device_operation_instances( instances, device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances_part2< + GemmMNKPadding>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp index 5410a0cc25..e9bbc4d732 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -22,6 +23,14 @@ void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_i add_device_operation_instances( instances, device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances_part2< + GemmMNPadding>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_comp_instance.cpp index 65e233ce08..a5bf286ca7 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_comp_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_comp_instance.cpp @@ -3,6 +3,7 @@ #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -31,6 +32,30 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_comp_instances( Empty_Tuple, NGKHW, ConvFwdDefault>{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2, + NGCHW, + GKYXC, + Empty_Tuple, + NGKHW, + ConvFwdDefault>{}); + } + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2, + NGCHW, + GKYXC, + Empty_Tuple, + NGKHW, + ConvFwdDefault>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instance.cpp index 36c8e3cb13..ea6ba831f1 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instance.cpp @@ -3,6 +3,7 @@ #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -31,6 +32,30 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances( Empty_Tuple, NGKHW, ConvFwdDefault>{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_part2<2, + NGCHW, + GKYXC, + Empty_Tuple, + NGKHW, + ConvFwdDefault>{}); + } + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_2x<2, + NGCHW, + GKYXC, + Empty_Tuple, + NGKHW, + ConvFwdDefault>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instance.cpp index d98b89c556..8f0a5ca425 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instance.cpp @@ -3,6 +3,7 @@ #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -31,6 +32,30 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instances( Empty_Tuple, NGKHW, ConvFwdDefault>{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_int8_comp_instances_part2<2, + NGCHW, + GKYXC, + Empty_Tuple, + NGKHW, + ConvFwdDefault>{}); + } + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_int8_comp_instances_2x<2, + NGCHW, + GKYXC, + Empty_Tuple, + NGKHW, + ConvFwdDefault>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp index 9f06347350..a344e35c8d 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp @@ -3,6 +3,7 @@ #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -56,6 +57,84 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances( Empty_Tuple, NHWGK, ConvFwdOddC>{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); + } + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp index 9b1c7ef65e..30a8b60bfc 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp @@ -3,6 +3,7 @@ #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -56,6 +57,84 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances( Empty_Tuple, NHWGK, ConvFwdOddC>{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_part2<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_part2<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_part2<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_part2<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); + } + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_2x<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_2x<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_2x<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_2x<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp index 78c2257b9a..6acbb7475c 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp @@ -3,6 +3,7 @@ #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -56,6 +57,84 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instances( Empty_Tuple, NHWGK, ConvFwdOddC>{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_int8_comp_instances_part2<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_int8_comp_instances_part2<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_int8_comp_instances_part2<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_int8_comp_instances_part2<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); + } + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_int8_comp_instances_2x<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_int8_comp_instances_2x<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_int8_comp_instances_2x<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_int8_comp_instances_2x<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp index 6fa4bc6e46..5e78cf41d8 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -3,6 +3,7 @@ #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -23,23 +24,46 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_inst PassThrough, PassThrough>>>& instances) { - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdDefault>{}); + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances_2x<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwd3x3>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances_2x<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd3x3>{}); + } + else + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd3x3>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp index 9fa56f48c7..fb494acd93 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -3,6 +3,7 @@ #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -23,23 +24,46 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_insta PassThrough, PassThrough>>>& instances) { - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_merged_groups_f16_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdDefault>{}); + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f16_instances_2x<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_merged_groups_f16_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwd3x3>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f16_instances_2x<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd3x3>{}); + } + else + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd3x3>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp index efc4640603..a94f687ef8 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp @@ -3,6 +3,7 @@ #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -46,6 +47,62 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances( Empty_Tuple, NDHWGK, ConvFwd1x1S1P0>{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); + } + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp index 3f3cd4b7d2..0c63345e7f 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp @@ -3,6 +3,7 @@ #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -46,6 +47,62 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances( Empty_Tuple, NDHWGK, ConvFwd1x1S1P0>{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_part2<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_part2<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_part2<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); + } + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_2x<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_2x<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_2x<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instance.cpp index 0bd53706be..20f19bd774 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instance.cpp @@ -30,13 +30,11 @@ using Empty_Tuple = ck::Tuple<>; using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; using device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_tile_instances = std::tuple< -// clang-format off + // clang-format off //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(CK_USE_AMD_MFMA_GFX950) -#endif // defined(CK_USE_AMD_MFMA_GFX950) DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F8, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1>, DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F8, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1>, DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F8, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1>, diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp index baf04cf12e..839d3559f7 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp @@ -39,19 +39,31 @@ static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; +// double rate mfma instances on gfx950 +template +using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances_2x = std::tuple< + // clang-format off + //###########################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + template using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances = std::tuple< -// clang-format off + // clang-format off //###########################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, -#endif // defined(CK_USE_AMD_MFMA_GFX950) // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, @@ -70,13 +82,11 @@ template using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances = std::tuple< -// clang-format off + // clang-format off //###########################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp index 6848774431..220cb8031f 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -27,6 +28,17 @@ void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_d ck::Tuple, Multiply, GemmDefault>{}); + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances_2x< + ck::Tuple, + ck::Tuple, + Multiply, + GemmDefault>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instance.cpp index bb2ea76aa4..6604078013 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -27,6 +28,17 @@ void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_k ck::Tuple, Multiply, GemmKPadding>{}); + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances_2x< + ck::Tuple, + ck::Tuple, + Multiply, + GemmKPadding>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp index 7439433f8a..157b45d713 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -27,6 +28,17 @@ void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_m ck::Tuple, Multiply, GemmMNKPadding>{}); + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances_2x< + ck::Tuple, + ck::Tuple, + Multiply, + GemmMNKPadding>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instance.cpp index b3afed0fd7..ca7e774ccd 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instance.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -27,6 +28,17 @@ void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_m ck::Tuple, Multiply, GemmMNPadding>{}); + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances_2x< + ck::Tuple, + ck::Tuple, + Multiply, + GemmMNPadding>{}); + } } } // namespace instance