From ca0015bf3983d3fd4d45768e789514927543ebe2 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Mon, 20 May 2024 08:34:45 -0700 Subject: [PATCH] aggregate device macros in ck_tile config header (#1297) [ROCm/composable_kernel commit: 06b891c5c2e75f7a2c4cf71c5597fdc16c236d50] --- ...ce_contraction_multiple_d_xdl_cshuffle.hpp | 3 +- ...ce_grouped_gemm_multi_abd_xdl_fixed_nk.hpp | 3 +- .../gpu/grid/gridwise_fpAintB_gemm_wmma.hpp | 5 ++-- .../grid/gridwise_gemm_xdl_cshuffle_v3.hpp | 10 +++---- ...ridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp | 10 +++---- ...ise_gemm_xdlops_splitk_lds_direct_load.hpp | 5 ++-- include/ck_tile/core/config.hpp | 30 +++++++++++++------ include/ck_tile/core/numeric/float8.hpp | 20 ++++++------- .../ck_tile/core/tensor/tile_elementwise.hpp | 2 +- .../warp/warp_gemm_attribute_mfma_impl.hpp | 24 +++++++-------- 10 files changed, 56 insertions(+), 56 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp index 4cc60f2836..9d5b74be6c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -53,8 +53,7 @@ __global__ void e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp index bf8788a3b2..1f60818e39 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp @@ -45,8 +45,7 @@ __global__ void const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t KBatch = 1; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp index 67e211ef8d..499eb7eb01 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp @@ -50,8 +50,7 @@ __global__ void const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \ - defined(__gfx1102__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size]; GridwiseGemm::template Run(p_a_grid, @@ -80,7 +79,7 @@ __global__ void ignore = b_element_op; ignore = c_element_op; ignore = block_2_ctile_map; -#endif // end of if (defined(__gfx1100__)) +#endif // end of if (defined(__gfx11__)) } // Assume B is Col-Major diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index 82c05937c7..dc45407e56 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -34,8 +34,7 @@ __global__ void // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); @@ -48,7 +47,7 @@ __global__ void karg); #else ignore = karg; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } template ( @@ -49,7 +48,7 @@ __global__ void karg.c_element_op); #else ignore = karg; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } template {}(reinterpret_cast(&x), x); -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) float max_fp8 = 240.0f; x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); union @@ -500,7 +500,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_sr_raw(float x) { constexpr int seed = 42; uint32_t rng = prand_generator_t{}(reinterpret_cast(&x), x); -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) union { float fval; @@ -526,7 +526,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_sr_raw(float x) CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_rtn_raw(float x) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) float max_fp8 = 240.0f; x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); union @@ -554,7 +554,7 @@ CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_rtn_raw(float x) } CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_rtn_raw(float x) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) union { float fval; @@ -598,7 +598,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_raw(float x, constant) CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) float fval; uint32_t i32val = static_cast(x); fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0); @@ -612,7 +612,7 @@ CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x) CK_TILE_HOST_DEVICE float bf8_to_float_raw(bf8_raw_t x) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) float fval; uint32_t i32val = static_cast(x); fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0); @@ -656,7 +656,7 @@ struct numeric_traits { static constexpr int exp = 4; static constexpr int mant = 3; -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) static constexpr int bias = 8; #else static constexpr int bias = 7; @@ -668,7 +668,7 @@ struct numeric_traits { static constexpr int exp = 5; static constexpr int mant = 2; -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) static constexpr int bias = 16; #else static constexpr int bias = 15; // IEEE diff --git a/include/ck_tile/core/tensor/tile_elementwise.hpp b/include/ck_tile/core/tensor/tile_elementwise.hpp index 90ad94b12b..48762b7225 100644 --- a/include/ck_tile/core/tensor/tile_elementwise.hpp +++ b/include/ck_tile/core/tensor/tile_elementwise.hpp @@ -112,7 +112,7 @@ namespace impl { template CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) // This API is designed to use the _pk_ serious of function constexpr auto in_tile_dstr = InTensor::get_tile_distribution(); diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index cb250516f4..dd164e72ea 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -36,8 +36,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8 CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__) +#if defined(__gfx9__) c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0); #else ck_tile::ignore = c_vec; @@ -49,8 +48,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8 // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__) +#if defined(__gfx9__) return bit_cast( __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0)); #else @@ -89,8 +87,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16 CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__) +#if defined(__gfx9__) c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0); #else ck_tile::ignore = c_vec; @@ -102,8 +99,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16 // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__) +#if defined(__gfx9__) return bit_cast( __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); #else @@ -143,7 +139,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx90a__) || defined(__gfx94__) c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0); #elif defined(__gfx908__) static_for<0, 2, 1>{}([&](auto k) { @@ -167,7 +163,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx90a__) || defined(__gfx94__) return bit_cast( __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0)); #elif defined(__gfx908__) @@ -220,7 +216,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx90a__) || defined(__gfx94__) c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0); #elif defined(__gfx908__) static_for<0, 2, 1>{}([&](auto k) { @@ -244,7 +240,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx90a__) || defined(__gfx94__) return bit_cast( __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); #elif defined(__gfx908__) @@ -299,7 +295,7 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) if constexpr(std::is_same_v && std::is_same_v) c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); @@ -333,7 +329,7 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) if constexpr(std::is_same_v && std::is_same_v) return bit_cast(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0));