From 1b409ffe5b4c5059f9bbd9571de5af1fabc9d972 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Tue, 28 Feb 2023 20:20:46 +0000 Subject: [PATCH] fix mfma_int8 on MI300 --- include/ck/ck.hpp | 12 +- .../device_gemm_xdl_waveletmodel_cshuffle.hpp | 6 +- ...gemm_softmax_gemm_permute_xdl_cshuffle.hpp | 6 +- ...tk_contraction_multiple_d_xdl_cshuffle.hpp | 6 +- ...ed_contraction_multiple_d_xdl_cshuffle.hpp | 6 +- .../device_batched_gemm_e_permute_xdl.hpp | 3 +- .../device_batched_gemm_gemm_xdl_cshuffle.hpp | 6 +- .../impl/device_batched_gemm_multi_d_xdl.hpp | 6 +- ...ultiple_d_gemm_multiple_d_xdl_cshuffle.hpp | 6 +- ...evice_batched_gemm_reduce_xdl_cshuffle.hpp | 3 +- ...gemm_softmax_gemm_permute_xdl_cshuffle.hpp | 6 +- ...batched_gemm_softmax_gemm_xdl_cshuffle.hpp | 6 +- .../device/impl/device_batched_gemm_xdl.hpp | 3 +- ...ce_contraction_multiple_d_xdl_cshuffle.hpp | 6 +- ...evice_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp | 3 +- .../impl/device_gemm_bias_e_permute_xdl.hpp | 6 +- ...gemm_multiple_d_layernorm_xdl_cshuffle.hpp | 6 +- ...emm_multiple_d_multiple_r_xdl_cshuffle.hpp | 6 +- .../device_gemm_multiple_d_xdl_cshuffle.hpp | 6 +- .../device/impl/device_gemm_xdl_cshuffle.hpp | 3 +- .../device_gemm_xdl_layernorm_cshuffle.hpp | 3 +- ...ed_contraction_multiple_d_xdl_cshuffle.hpp | 6 +- ...nv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp | 3 +- ...bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp | 3 +- ...fwd_multiple_d_multiple_r_xdl_cshuffle.hpp | 3 +- ...ouped_conv_fwd_multiple_d_xdl_cshuffle.hpp | 3 +- .../device/impl/device_grouped_gemm_xdl.hpp | 3 +- ...e_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp | 3 +- .../gridwise_gemm_reduce_xdl_cshuffle_v1.hpp | 3 +- .../grid/gridwise_gemm_xdl_cshuffle_v1.hpp | 3 +- ...ridwise_gemm_xdl_layernorm_cshuffle_v1.hpp | 3 +- .../grid/gridwise_gemm_xdlops_bwd_weight.hpp | 3 +- .../gridwise_gemm_xdlops_skip_b_lds_v1.hpp | 3 +- .../gpu/grid/gridwise_gemm_xdlops_v2r3.hpp | 3 +- .../gpu/grid/gridwise_gemm_xdlops_v2r4.hpp | 3 +- 
.../gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp | 3 +- .../gpu/grid/gridwise_gemm_xdlops_v3r1.hpp | 3 +- .../gpu/grid/gridwise_gemm_xdlops_v3r2.hpp | 3 +- .../gpu/grid/gridwise_gemm_xdlops_v3r3.hpp | 3 +- .../tensor_operation/gpu/warp/xdlops_gemm.hpp | 113 +++++++++--------- include/ck/utility/amd_xdlops.hpp | 59 +++++---- include/ck/utility/data_type.hpp | 2 + script/cmake-ck-dev.sh | 4 +- 43 files changed, 214 insertions(+), 135 deletions(-) diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 6213015436..d286bb7478 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -51,8 +51,8 @@ #ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing #elif defined(__gfx803__) || defined(__gfx900__) // for GPU code #define CK_USE_AMD_V_MAC_F32 -#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx1030__) || defined(__gfx940__) // for GPU code +#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__) || \ + defined(__gfx940__) // for GPU code #define CK_USE_AMD_V_FMAC_F32 #define CK_USE_AMD_V_DOT2_F32_F16 #define CK_USE_AMD_V_DOT4_I32_I8 @@ -65,10 +65,14 @@ #define CK_USE_AMD_MFMA #endif -#if (defined(__gfx90a__) || defined(__gfx940__)) +#if(defined(__gfx90a__) || defined(__gfx940__)) #define CK_USE_AMD_MFMA_BF16_1K_OP #endif +#if defined(__gfx940__) +#define CK_USE_AMD_MFMA_GFX940 +#endif + // WMMA instruction #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_USE_AMD_WMMA @@ -94,7 +98,7 @@ #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0 #endif -#if (defined(__gfx90a__) || defined(__gfx940__)) // for GPU code +#if(defined(__gfx90a__) || defined(__gfx940__)) // for GPU code #define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1 #else #define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0 diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_waveletmodel_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_waveletmodel_cshuffle.hpp index bb8a10e955..af38f14254 100644 
--- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_waveletmodel_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_waveletmodel_cshuffle.hpp @@ -47,7 +47,8 @@ __global__ void e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -416,7 +417,8 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm(p_a_grid, @@ -581,7 +582,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940")) + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || + ck::get_device_name() == "gfx940")) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp index 7e5b68207f..69c842137e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -55,7 +55,8 @@ __global__ void const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / num_batches); const index_t g_idx = 
__builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_e_permute_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_e_permute_xdl.hpp index 497e79ef1c..a7bce886f4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_e_permute_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_e_permute_xdl.hpp @@ -51,7 +51,8 @@ __global__ void e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -456,7 +457,8 @@ struct DeviceGemmBiasEPermute_Xdl : public DeviceGemmBiasCPermute( @@ -851,7 +852,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940")) + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || + ck::get_device_name() == "gfx940")) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp index 032597d27d..3863704cf0 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -60,7 +60,8 @@ __global__ void const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || 
defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -554,7 +555,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940")) + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || + ck::get_device_name() == "gfx940")) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp index 62eaef4467..0c845ab5b3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp @@ -51,7 +51,8 @@ __global__ void e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -490,7 +491,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp index 6aaaf69ccd..a3f5324713 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -54,7 
+54,8 @@ __global__ void const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp index 59d40801d1..1213cdc263 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -44,7 +44,8 @@ __global__ void c_grid_desc_mblock_mperblock_nblock_nperblock, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp index 20982471e8..2d4ebe7076 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp @@ -57,7 +57,8 @@ __global__ void const C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ 
+ defined(__gfx940__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; // TODO ANT: separate into MMA + Epilogue diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index f82e739dea..65401fda9e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -165,7 +165,8 @@ __global__ void const CElementwiseOperation c_element_op, const CBlockClusterAdaptor c_block_cluster_adaptor) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__)) constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp index 1d507ed14c..8d86f3c1d7 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp @@ -44,7 +44,8 @@ __global__ void const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp index f397579a3a..73d7088bc8 100644 --- 
a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -43,7 +43,8 @@ __global__ void const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp index f3dca72795..55f465a037 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp @@ -42,7 +42,8 @@ __global__ void const CElementwiseOperation c_element_op, const CBlockClusterAdaptor c_block_cluster_adaptor) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__)) constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp index 4f651a33fb..f0ce2e3bdb 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp @@ -44,7 +44,8 @@ __global__ void const CElementwiseOperation c_element_op, const CBlockClusterAdaptor c_block_cluster_adaptor) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || 
defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__)) constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp index 6a54f66df7..8259927fec 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp @@ -46,7 +46,8 @@ __global__ void const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp index b273b0e614..5d5fdae170 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp @@ -49,7 +49,8 @@ __global__ void const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp index 4eec0004f8..dc83f8e984 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp +++ 
b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp @@ -53,7 +53,8 @@ __global__ void const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run( diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index dcca5ce37f..319487bc05 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -9,7 +9,6 @@ namespace ck { -#if (defined(__gfx908__) || defined(__gfx90a__)) enum struct MfmaInstr { mfma_f32_32x32x1xf32 = 0, @@ -28,30 +27,10 @@ enum struct MfmaInstr mfma_f32_16x16x8bf16, mfma_i32_32x32x8i8, mfma_i32_16x16x16i8, - mfma_f64_16x16x4f64 -}; -#elif (defined(__gfx940__)) -enum struct MfmaInstr -{ - mfma_f32_32x32x1xf32 = 0, - mfma_f32_16x16x1xf32, - mfma_f32_4x4x1xf32, - mfma_f32_32x32x2xf32, - mfma_f32_16x16x4xf32, - mfma_f32_32x32x4f16, - mfma_f32_16x16x4f16, - mfma_f32_4x4x4f16, - mfma_f32_32x32x8f16, - mfma_f32_16x16x16f16, - mfma_f32_32x32x8bf16_1k, - mfma_f32_16x16x16bf16_1k, - mfma_f32_32x32x4bf16, - mfma_f32_16x16x8bf16, mfma_i32_32x32x16i8, - mfma_i32_16x16x16i8, + mfma_i32_16x16x32i8, mfma_f64_16x16x4f64 }; -#endif template struct mfma_type; @@ -365,7 +344,6 @@ struct mfma_type } }; -#if (defined(__gfx908__) || defined(__gfx90a__)) template <> struct mfma_type { @@ -387,29 +365,6 @@ struct mfma_type intrin_mfma_i32_32x32x8i8::Run(a, b, reg_c); } }; -#elif (defined(__gfx940__)) -template <> -struct mfma_type -{ - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_per_blk = 4; - static constexpr index_t num_regs_per_blk = 16; - static constexpr 
index_t num_threads_per_blk = 32; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = 2; - static constexpr index_t num_output_blks = 1; - static constexpr index_t m_per_blk = 32; - static constexpr index_t n_per_blk = 32; - static constexpr index_t k_per_blk = 4; - static constexpr bool is_k_reduction = true; - - template - __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const - { - intrin_mfma_i32_32x32x16i8::Run(a, b, reg_c); - } -}; -#endif template <> struct mfma_type @@ -433,6 +388,50 @@ struct mfma_type } }; +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_i32_32x32x16i8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_i32_16x16x32i8::Run(a, b, reg_c); + } +}; + template <> struct mfma_type { 
@@ -571,25 +570,29 @@ struct MfmaSelector #endif } -#if (defined(__gfx908__) || defined(__gfx90a__)) - template <> - static constexpr auto GetMfma() - { - return MfmaInstr::mfma_i32_32x32x8i8; - } -#elif (defined(__gfx940__)) +#if defined(CK_USE_AMD_MFMA_GFX940) template <> static constexpr auto GetMfma() { return MfmaInstr::mfma_i32_32x32x16i8; } -#endif - + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_i32_16x16x32i8; + } +#else + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_i32_32x32x8i8; + } template <> static constexpr auto GetMfma() { return MfmaInstr::mfma_i32_16x16x16i8; } +#endif static constexpr auto selected_mfma = mfma_type()>{}; diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index bc9676f1f7..a742496fc1 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -259,7 +259,6 @@ struct intrin_mfma_f32_16x16x8bf16<16, 16> } }; -#if (defined(__gfx908__) || defined(__gfx90a__)) template struct intrin_mfma_i32_32x32x8i8; @@ -278,26 +277,6 @@ struct intrin_mfma_i32_32x32x8i8<32, 32> 0); } }; -#elif (defined(__gfx940__)) -template -struct intrin_mfma_i32_32x32x16i8; - -template <> -struct intrin_mfma_i32_32x32x16i8<32, 32> -{ - template - __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c) - { - reg_c.template AsType()(Number<0>{}) = - __builtin_amdgcn_mfma_i32_32x32x16_i8(bit_cast(reg_a), - bit_cast(reg_b), - reg_c.template AsType()[Number<0>{}], - 0, - 0, - 0); - } -}; -#endif template struct intrin_mfma_i32_16x16x16i8; @@ -318,6 +297,44 @@ struct intrin_mfma_i32_16x16x16i8<16, 16> } }; +template +struct intrin_mfma_i32_32x32x16i8; + +template <> +struct intrin_mfma_i32_32x32x16i8<32, 32> +{ + template + __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_i32_32x32x16_i8(bit_cast(reg_a), + 
bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); + } +}; + +template +struct intrin_mfma_i32_16x16x32i8; + +template <> +struct intrin_mfma_i32_16x16x32i8<16, 16> +{ + template + __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_i32_16x16x32_i8(bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); + } +}; + template struct intrin_mfma_f64_16x16x4f64; diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 40ee8b617e..e97b932ab4 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -898,6 +898,8 @@ struct vector_type } }; +using int64_t = long; + // fp64 using double2_t = typename vector_type::type; using double4_t = typename vector_type::type; diff --git a/script/cmake-ck-dev.sh b/script/cmake-ck-dev.sh index 38f6581831..b85889743b 100755 --- a/script/cmake-ck-dev.sh +++ b/script/cmake-ck-dev.sh @@ -10,8 +10,8 @@ cmake -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_CXX_FLAGS="-O3 -ftemplate-backtrace-limit=0 -gline-tables-only -save-temps=$PWD" \ -D CMAKE_BUILD_TYPE=Release \ --D BUILD_DEV=ON \ --D GPU_TARGETS="gfx90a" \ +-D BUILD_DEV=OFF \ +-D GPU_TARGETS="gfx940" \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -D USE_BITINT_EXTENSION_INT4=OFF \ ${MY_PROJECT_SOURCE}