diff --git a/CHANGELOG.md b/CHANGELOG.md
index ab2076c0d8..0f04935b8d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -22,6 +22,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
 * Added benchmarking support for tile engine GEMM.
 * Added Ping-pong scheduler support for GEMM operation along the K dimension.
 * Added rotating buffer feature for CK_Tile GEMM.
+* Added int8 support for CK_Tile GEMM.

 ### Optimized

diff --git a/example/ck_tile/03_gemm/README.md b/example/ck_tile/03_gemm/README.md
index 4c16f13cef..da37159aeb 100644
--- a/example/ck_tile/03_gemm/README.md
+++ b/example/ck_tile/03_gemm/README.md
@@ -30,7 +30,7 @@ args:
 -stride_c  Tensor C stride (default:0)
 -v         0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
 -e         Absolute error tolerance (default:1e-5)
--prec      data type. fp16/bf16/fp8/bf8 (default:fp16)
+-prec      data type. fp16/bf16/fp8/bf8/int8 (default:fp16)
 -warmup    number of iterations before benchmark the kernel (default:10)
 -repeat    number of iterations to benchmark the kernel (default:100)
 -timer     gpu:gpu timer, cpu:cpu timer (default:gpu)

diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp
index 090a98486e..80c18cdb87 100644
--- a/example/ck_tile/03_gemm/gemm_basic.cpp
+++ b/example/ck_tile/03_gemm/gemm_basic.cpp
@@ -212,6 +212,11 @@ int run_gemm_example(int argc, char* argv[])
         return run_gemm_example_prec_type(
             a_layout, b_layout, argc, argv);
     }
+    else if(data_type == "int8")
+    {
+        return run_gemm_example_prec_type<ck_tile::int8_t>(
+            a_layout, b_layout, argc, argv);
+    }
     else if(data_type == "pk_int4_t")
     {
         // TODO: Add support for bhalf_t ADataType

diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp
index 101e195903..5f767d56aa 100644
--- a/example/ck_tile/03_gemm/gemm_utils.hpp
+++ b/example/ck_tile/03_gemm/gemm_utils.hpp
@@ -1,4 +1,3 @@
-
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

@@ -16,6 +15,25 @@
 #define CK_TILE_PIPELINE_COMPUTE_V4 3
 #define CK_TILE_PIPELINE_COMPUTE_V5 4

+// Temporary workaround: select K_Warp_Tile based on PrecType and whether we compile for gfx950.
+template <typename PrecType, ck_tile::index_t M_Warp_Tile>
+constexpr ck_tile::index_t get_k_warp_tile()
+{
+#if defined(__gfx950__)
+    constexpr bool is_8bit_float =
+        std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
+    if constexpr(M_Warp_Tile == 32)
+        return is_8bit_float ? 64 : 16;
+    else
+        return is_8bit_float ? 128 : 32;
+#else
+    if constexpr(M_Warp_Tile == 32)
+        return 16;
+    else
+        return 32;
+#endif
+}
+
 struct GemmConfigBase
 {
     static constexpr bool kPadM = false;

@@ -90,7 +108,7 @@ struct GemmConfigComputeV3 : public GemmConfigBase
     static constexpr ck_tile::index_t M_Warp_Tile = 32;
     static constexpr ck_tile::index_t N_Warp_Tile = 32;
-    static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 16 : 64;
+    static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();

     static constexpr bool DoubleSmemBuffer     = false;
     static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;

@@ -109,7 +127,7 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase
     static constexpr ck_tile::index_t M_Warp_Tile = 32;
     static constexpr ck_tile::index_t N_Warp_Tile = 32;
-    static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 16 : 64;
+    static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();

     static constexpr bool DoubleSmemBuffer     = false;
     static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;

@@ -128,7 +146,7 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
     static constexpr ck_tile::index_t M_Warp_Tile = 16;
     static constexpr ck_tile::index_t N_Warp_Tile = 16;
-    static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 32 : 128;
+    static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();

     static constexpr bool DoubleSmemBuffer     = false;
     static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;

@@ -151,7 +169,7 @@ struct GemmConfigComputeV4 : public GemmConfigBase
     static constexpr ck_tile::index_t M_Warp_Tile = 32;
     static constexpr ck_tile::index_t N_Warp_Tile = 32;
-    static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 16 : 64;
+    static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();

     static constexpr bool DoubleSmemBuffer     = true;
     static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;

@@ -170,7 +188,7 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase
     static constexpr ck_tile::index_t M_Warp_Tile = 32;
     static constexpr ck_tile::index_t N_Warp_Tile = 32;
-    static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 16 : 64;
+    static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();

     static constexpr bool DoubleSmemBuffer     = true;
     static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
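As a quick sketch of how the new helper resolves, here are illustrative compile-time checks (not part of the patch) for a gfx950 build, i.e. assuming `__gfx950__` is defined:

    // Hypothetical static_asserts exercising get_k_warp_tile on gfx950.
    static_assert(get_k_warp_tile<ck_tile::fp8_t, 32>() == 64);   // 8-bit float widens K for 32x32 MFMA
    static_assert(get_k_warp_tile<ck_tile::bf8_t, 16>() == 128);  // ... and for the 16x16 variant
    static_assert(get_k_warp_tile<ck_tile::half_t, 32>() == 16);  // fp16 keeps the old K = 16
    static_assert(get_k_warp_tile<ck_tile::int8_t, 16>() == 32);  // int8 is not an 8-bit float here

Note that int8 deliberately falls into the non-float branch, which matches the 32x32x16 / 16x16x32 int8 warp tiles registered later in this patch.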
@@ -189,7 +207,7 @@ struct GemmConfigComputeV5 : public GemmConfigBase
     static constexpr ck_tile::index_t M_Warp_Tile = 32;
     static constexpr ck_tile::index_t N_Warp_Tile = 32;
-    static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 16 : 64;
+    static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();

     static constexpr bool DoubleSmemBuffer     = false;
     static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5;

@@ -245,6 +263,15 @@ struct GemmTypeConfig
     using CDataType = ck_tile::half_t;
 };

+template <>
+struct GemmTypeConfig<ck_tile::int8_t>
+{
+    using ADataType   = ck_tile::int8_t;
+    using BDataType   = ck_tile::int8_t;
+    using AccDataType = int32_t;
+    using CDataType   = int32_t;
+};
+
 template <typename DataType>
 struct DataTypeTraits;

@@ -260,6 +287,12 @@ struct DataTypeTraits
     static constexpr const char* name = "fp64";
 };

+template <>
+struct DataTypeTraits<int32_t>
+{
+    static constexpr const char* name = "int32";
+};
+
 template <>
 struct DataTypeTraits
 {

@@ -290,6 +323,12 @@ struct DataTypeTraits
     static constexpr const char* name = "pk_int4_t";
 };

+template <>
+struct DataTypeTraits<ck_tile::int8_t>
+{
+    static constexpr const char* name = "int8";
+};
+
 template
 struct PipelineTypeTraits;

diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc
index 140107bfb4..d3ef974d91 100644
--- a/example/ck_tile/03_gemm/run_gemm_example.inc
+++ b/example/ck_tile/03_gemm/run_gemm_example.inc
@@ -297,8 +297,8 @@ int run_gemm_example_with_layouts(int argc,
     if(init_method == 0)
     {
-        ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
-        ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
+        ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
+        ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
     }
     else if(init_method == 1)
     {

@@ -415,29 +415,19 @@ int run_gemm_example_with_layouts(int argc,
         // Restore input for B for gpu reference
         b_k_n_dev_buf.ToDevice(b_k_n.data());
     }
+
+    // memory on host to store gpu reference result
     ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(
         ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
+    // memory on device to store gpu reference result
     ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes());
+
     c_m_n_gpu_ref.SetZero();
     c_m_n_gpu_buf_ref.SetZero();

-    ADataType* d_A;
-    BDataType* d_B;
-    CDataType* d_C;
-
-    ck_tile::hip_check_error(hipMalloc(&d_A, a_m_k.get_element_space_size_in_bytes()));
-    ck_tile::hip_check_error(hipMalloc(&d_B, b_k_n.get_element_space_size_in_bytes()));
-    ck_tile::hip_check_error(
-        hipMalloc(&d_C, c_m_n_dev_result.get_element_space_size_in_bytes()));
-
-    ck_tile::hip_check_error(hipMemcpy(d_A,
-                                       a_m_k_dev_buf.GetDeviceBuffer(),
-                                       a_m_k.get_element_space_size_in_bytes(),
-                                       hipMemcpyHostToDevice));
-    ck_tile::hip_check_error(hipMemcpy(d_B,
-                                       b_k_n_dev_buf.GetDeviceBuffer(),
-                                       b_k_n.get_element_space_size_in_bytes(),
-                                       hipMemcpyHostToDevice));
+    ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
+    BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
+    CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());

     ck_tile::reference_gemm_gpu<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
         d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);

-    ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(),
-                                       d_C,
-                                       c_m_n_dev_result.get_element_space_size_in_bytes(),
-                                       hipMemcpyDeviceToHost));
-
-    ck_tile::hip_check_error(hipFree(d_A));
-    ck_tile::hip_check_error(hipFree(d_B));
-    ck_tile::hip_check_error(hipFree(d_C));
-
     c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
+
     const float max_accumulated_value =
         *std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());
     const auto rtol_atol = calculate_rtol_atol(
diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp
index ecfaa92b9a..c2c3fc1fa4 100644
--- a/example/ck_tile/03_gemm/universal_gemm.cpp
+++ b/example/ck_tile/03_gemm/universal_gemm.cpp
@@ -299,6 +299,13 @@ int run_gemm_example(int argc, char* argv[])
                                           ck_tile::bf8_t,
                                           ck_tile::half_t>(a_layout, b_layout, argc, argv);
     }
+    else if(data_type == "int8")
+    {
+        return run_gemm_example_prec_type<GemmConfig,
+                                          ck_tile::int8_t,
+                                          ck_tile::int8_t,
+                                          ck_tile::int32_t>(a_layout, b_layout, argc, argv);
+    }
     else if(data_type == "pk_int4_t")
     {
         // TODO: Add support for bhalf_t ADataType

diff --git a/include/ck_tile/core/numeric/integer.hpp b/include/ck_tile/core/numeric/integer.hpp
index 3faf3020a6..502026c231 100644
--- a/include/ck_tile/core/numeric/integer.hpp
+++ b/include/ck_tile/core/numeric/integer.hpp
@@ -7,6 +7,7 @@
 namespace ck_tile {

 using index_t      = int32_t;
+using int32_t      = int32_t;
 using long_index_t = int64_t;
 using int8_t       = int8_t;

diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp
index 8d19337b86..231a2c832b 100644
--- a/include/ck_tile/core/tensor/buffer_view.hpp
+++ b/include/ck_tile/core/tensor/buffer_view.hpp
@@ -1009,6 +1009,15 @@ struct buffer_view
                          std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
                         (std::is_same_v<remove_cvref_t<T>, int8x16_t> &&
                          std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
+                        // int8 on thread buffer
+                        (std::is_same_v<remove_cvref_t<T>, int8_t> &&
+                         std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 2>>) ||
+                        (std::is_same_v<remove_cvref_t<T>, int8_t> &&
+                         std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 4>>) ||
+                        (std::is_same_v<remove_cvref_t<T>, int8_t> &&
+                         std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 8>>) ||
+                        (std::is_same_v<remove_cvref_t<T>, int8_t> &&
+                         std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 16>>) ||
                         // ext_vector_type for pk_int4 must use int8_t as type
                         (std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
                          std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 2>>) ||

@@ -1031,6 +1040,8 @@ struct buffer_view
             if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
                           std::is_same_v<remove_cvref_t<X>, int8_t>) ||
+                         (std::is_same_v<remove_cvref_t<T>, int8_t> &&
+                          std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 1>>) ||
                          (std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
                           std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 1>>))
             {

@@ -1041,6 +1052,8 @@ struct buffer_view
             else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
                                std::is_same_v<remove_cvref_t<X>, int8x2_t>) ||
+                              (std::is_same_v<remove_cvref_t<T>, int8_t> &&
+                               std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 2>>) ||
                               (std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
                                std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 2>>))
             {

@@ -1051,6 +1064,8 @@ struct buffer_view
             else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
                                std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
+                              (std::is_same_v<remove_cvref_t<T>, int8_t> &&
+                               std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 4>>) ||
                               (std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
                                std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 4>>))
             {

@@ -1061,6 +1076,8 @@ struct buffer_view
             else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
                                std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
+                              (std::is_same_v<remove_cvref_t<T>, int8_t> &&
+                               std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 8>>) ||
                               (std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
                                std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 8>>))
             {

diff --git a/include/ck_tile/host/host_tensor.hpp b/include/ck_tile/host/host_tensor.hpp
index b8c764809c..ecbc009b85 100644
--- a/include/ck_tile/host/host_tensor.hpp
+++ b/include/ck_tile/host/host_tensor.hpp
@@ -722,6 +722,8 @@ struct HostTensor
                 file << type_convert<float>(itm) << std::endl;
             else if(dtype == "int")
                 file << type_convert<int>(itm) << std::endl;
+            else if(dtype == "int8_t")
+                file << static_cast<int>(type_convert<int8_t>(itm)) << std::endl;
             else
                 // TODO: we didn't implement operator<< for all custom
                 // data types, here fall back to float in case compile error
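The `static_cast<int>` in the HostTensor change above matters because streaming an `int8_t` through `operator<<` selects the character overload, not the numeric one. A minimal standalone illustration:

    #include <cstdint>
    #include <iostream>

    int main()
    {
        std::int8_t v = 65;
        std::cout << v << "\n";                    // prints the character 'A'
        std::cout << static_cast<int>(v) << "\n";  // prints the number 65
        return 0;
    }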
diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp
index c4d527da63..d4e23d12dd 100644
--- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp
+++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp
@@ -215,7 +215,7 @@ struct BlockUniversalGemmAsBsCr
         using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));

         ALdsTile a_warp_tile_;
-        ALdsTile b_warp_tile_;
+        BLdsTile b_warp_tile_;

         // C += A * B
         template

diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp
index f243aceda8..185abccd3f 100644
--- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp
+++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp
@@ -282,4 +282,19 @@ using WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution =
         2,
         swizzle_factor>>;

+// int8
+using WarpGemmMfma_i32_32x32x16_i8_i8 = WarpGemmImpl<
+    WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<WGAttrCtlEnum::Default_>>>;
+
+using WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed =
+    WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
+        WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<WGAttrCtlEnum::Default_>>>;
+
+using WarpGemmMfma_i32_16x16x32_i8_i8 = WarpGemmImpl<
+    WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<WGAttrCtlEnum::Default_>>>;
+
+using WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed =
+    WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
+        WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<WGAttrCtlEnum::Default_>>>;
+
 } // namespace ck_tile

diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
index 7f7a835a69..80f38f263b 100644
--- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
+++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
@@ -1578,8 +1578,8 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
         DISPATCH_MFMA_CTRL_("v_mfma_i32_32x32x16_i8", Ctrl)
         else
         {
-#if defined(__gfx94__)
-            c_vec = __builtin_amdgcn_mfma_i32_32x32x8i8(
+#if defined(__gfx94__) || defined(__gfx95__)
+            c_vec = __builtin_amdgcn_mfma_i32_32x32x16_i8(
                 bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
 #elif defined(__gfx908__) || defined(__gfx90a__)
             static_for<0, 8, 1>{}([&](auto k) {

@@ -1609,6 +1609,183 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
     }
 };

+template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
+struct WarpGemmAttributeMfmaImpl_i32_16x16x32_i8
+{
+    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
+    using ADataType = int8_t;
+    using BDataType = int8_t;
+    using CDataType = int32_t;
+
+    using AVecType = ext_vector_t<int8_t, 8>;
+    using BVecType = ext_vector_t<int8_t, 8>;
+    using CVecType = ext_vector_t<int32_t, 4>;
+
+    static constexpr index_t kM = 16;
+    static constexpr index_t kN = 16;
+    static constexpr index_t kK = 32;
+
+    static constexpr index_t kAMBlock = 1;
+    static constexpr index_t kBNBlock = 1;
+
+    static constexpr index_t kAMLane     = 16;
+    static constexpr index_t kBNLane     = 16;
+    static constexpr index_t kABKLane    = 4;
+    static constexpr index_t kABKPerLane = 8;
+
+    static constexpr index_t kCMLane     = 4;
+    static constexpr index_t kCNLane     = 16;
+    static constexpr index_t kCM0PerLane = 1;
+    static constexpr index_t kCM1PerLane = 4; // write to 4x AccVGPRs
+
+    // c_vec += a_vec * b_vec
+    template <bool post_nop_ = false>
+    CK_TILE_DEVICE void operator()(CVecType& c_vec,
+                                   const AVecType& a_vec,
+                                   const BVecType& b_vec,
+                                   bool_constant<post_nop_> = {}) const
+    {
+        DISPATCH_MFMA_CTRL_("v_mfma_i32_16x16x32_i8", Ctrl)
+        else
+        {
+#if defined(__gfx94__) || defined(__gfx95__)
+            c_vec = __builtin_amdgcn_mfma_i32_16x16x32_i8(
+                bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
+#else
+            ck_tile::ignore = c_vec;
+            ck_tile::ignore = a_vec;
+            ck_tile::ignore = b_vec;
+#endif
+        }
+    }
+
+    // c_vec = a_vec * b_vec
+    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
+    {
+        CVecType c_vec{0};
+        operator()(c_vec, a_vec, b_vec);
+        return c_vec;
+    }
+};
+
+template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
+struct WarpGemmAttributeMfmaImpl_i32_16x16x64_i8
+{
+    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
+    using ADataType = int8_t;
+    using BDataType = int8_t;
+    using CDataType = int32_t;
+
+    using AVecType = ext_vector_t<int8_t, 16>;
+    using BVecType = ext_vector_t<int8_t, 16>;
+    using CVecType = ext_vector_t<int32_t, 4>;
+
+    static constexpr index_t kM = 16;
+    static constexpr index_t kN = 16;
+    static constexpr index_t kK = 64;
+
+    static constexpr index_t kAMBlock = 1;
+    static constexpr index_t kBNBlock = 1;
+
+    static constexpr index_t kAMLane     = 16;
+    static constexpr index_t kBNLane     = 16;
+    static constexpr index_t kABKLane    = 4;
+    static constexpr index_t kABKPerLane = 16;
+
+    static constexpr index_t kCMLane     = 4;
+    static constexpr index_t kCNLane     = 16;
+    static constexpr index_t kCM0PerLane = 1;
+    static constexpr index_t kCM1PerLane = 4; // write to 4x AccVGPRs
+
+    // c_vec += a_vec * b_vec
+    template <bool post_nop_ = false>
+    CK_TILE_DEVICE void operator()(CVecType& c_vec,
+                                   const AVecType& a_vec,
+                                   const BVecType& b_vec,
+                                   bool_constant<post_nop_> = {}) const
+    {
+        DISPATCH_MFMA_CTRL_("v_mfma_i32_16x16x64_i8", Ctrl)
+        else
+        {
+#if defined(__gfx95__)
+            c_vec = __builtin_amdgcn_mfma_i32_16x16x64_i8(
+                bit_cast<ext_vector_t<int32_t, 4>>(a_vec),
+                bit_cast<ext_vector_t<int32_t, 4>>(b_vec),
+                c_vec, 0, 0, 0);
+#else
+            ck_tile::ignore = c_vec;
+            ck_tile::ignore = a_vec;
+            ck_tile::ignore = b_vec;
+#endif
+        }
+    }
+
+    // c_vec = a_vec * b_vec
+    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
+    {
+        CVecType c_vec{0};
+        operator()(c_vec, a_vec, b_vec);
+        return c_vec;
+    }
+};
+
+template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
+struct WarpGemmAttributeMfmaImpl_i32_32x32x32_i8
+{
+    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
+    using ADataType = int8_t;
+    using BDataType = int8_t;
+    using CDataType = int32_t;
+
+    using AVecType = ext_vector_t<int8_t, 16>;
+    using BVecType = ext_vector_t<int8_t, 16>;
+    using CVecType = ext_vector_t<int32_t, 16>;
+
+    static constexpr index_t kM = 32;
+    static constexpr index_t kN = 32;
+    static constexpr index_t kK = 32;
+
+    static constexpr index_t kAMBlock = 1;
+    static constexpr index_t kBNBlock = 1;
+
+    static constexpr index_t kAMLane     = 32;
+    static constexpr index_t kBNLane     = 32;
+    static constexpr index_t kABKLane    = 2;
+    static constexpr index_t kABKPerLane = 16;
+
+    static constexpr index_t kCMLane     = 2;
+    static constexpr index_t kCNLane     = 32;
+    static constexpr index_t kCM0PerLane = 4;
+    static constexpr index_t kCM1PerLane = 4;
+
+    // c_vec += a_vec * b_vec
+    template <bool post_nop_ = false>
+    CK_TILE_DEVICE void operator()(CVecType& c_vec,
+                                   const AVecType& a_vec,
+                                   const BVecType& b_vec,
+                                   bool_constant<post_nop_> = {}) const
+    {
+        DISPATCH_MFMA_CTRL_("v_mfma_i32_32x32x32_i8", Ctrl)
+        else
+        {
+#if defined(__gfx95__)
+            c_vec = __builtin_amdgcn_mfma_i32_32x32x32_i8(
+                bit_cast<ext_vector_t<int32_t, 4>>(a_vec),
+                bit_cast<ext_vector_t<int32_t, 4>>(b_vec),
+                c_vec, 0, 0, 0);
+#else
+            ck_tile::ignore = c_vec;
+            ck_tile::ignore = a_vec;
+            ck_tile::ignore = b_vec;
+#endif
+        }
+    }
+
+    // c_vec = a_vec * b_vec
+    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
+    {
+        CVecType c_vec{0};
+        operator()(c_vec, a_vec, b_vec);
+        return c_vec;
+    }
+};
+
 #undef DISPATCH_MFMA_

 } // namespace ck_tile
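For reference, the math that one of the new int8 warp tiles performs, written as a scalar loop (a sketch of the semantics only; the actual lane mapping and VGPR packing follow the kAMLane/kABKLane constants above). For the 16x16x32 shape:

    #include <cstdint>

    // Scalar model of v_mfma_i32_16x16x32_i8: C[16][16] += A[16][32] * B[32][16],
    // with int8 inputs widened to int32 before the multiply-accumulate, so the
    // accumulator never overflows for a single MFMA step.
    void mfma_i32_16x16x32_i8_reference(const std::int8_t A[16][32],
                                        const std::int8_t B[32][16],
                                        std::int32_t C[16][16])
    {
        for(int m = 0; m < 16; ++m)
            for(int n = 0; n < 16; ++n)
                for(int k = 0; k < 32; ++k)
                    C[m][n] += std::int32_t(A[m][k]) * std::int32_t(B[k][n]);
    }

This is why the int8 type configs above pair `int8_t` A/B with an `int32_t` accumulator and C type.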
diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
index b2f5d56d01..b6ada83532 100644
--- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
+++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
@@ -11,7 +11,7 @@
 namespace ck_tile {
 namespace impl {
 template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; };
 template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; };
 template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16; };

@@ -37,10 +38,12 @@
 template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };

 // fp16 2:4 structural sparsity
+// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
 template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32, 16, false, false, true> { using Type = WarpGemmSmfmacF16F16F32M32N32K16; };
 template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 16, 16, 32, false, false, true> { using Type = WarpGemmSmfmacF16F16F32M16N16K32; };

 // bf16
+// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
 template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
 template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
 template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; };

@@ -56,6 +59,7 @@
 template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };

 // fp8
+// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
 template<> struct WarpGemmMfmaDispatcher<fp8_t, fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
 template<> struct WarpGemmMfmaDispatcher<fp8_t, fp8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8; };
 template<> struct WarpGemmMfmaDispatcher<fp8_t, fp8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; };

@@ -81,12 +85,19 @@
 template<> struct WarpGemmMfmaDispatcher<bf8_t, fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8; };
 template<> struct WarpGemmMfmaDispatcher<bf8_t, bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8; };

+// int8
+// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
+template<> struct WarpGemmMfmaDispatcher<int8_t, int8_t, int32_t, 32, 32, 16, false> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8; };
+template<> struct WarpGemmMfmaDispatcher<int8_t, int8_t, int32_t, 32, 32, 16, true> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed; };
+template<> struct WarpGemmMfmaDispatcher<int8_t, int8_t, int32_t, 16, 16, 32, false> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8; };
+template<> struct WarpGemmMfmaDispatcher<int8_t, int8_t, int32_t, 16, 16, 32, true> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed; };
+
 // clang-format on
 } // namespace impl

 template <typename ADataType, typename BDataType, typename AccDataType,
           index_t MPerWave, index_t NPerWave, index_t KPerWave,
           bool TransposeC, bool SwizzleA = false, bool UseStructuredSparsity = false>
 using WarpGemmMfmaDispatcher =
     typename impl::WarpGemmMfmaDispatcher<ADataType, BDataType, AccDataType,
                                           MPerWave, NPerWave, KPerWave,
                                           TransposeC, SwizzleA, UseStructuredSparsity>::Type;

 using KernelTypesCompV3 = ::testing::Types<
-    std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>,
-    std::tuple< Row, Row, Row,  F8,  F8, F32, F16, Intrawave, CompV3>,
-    std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>,
-    std::tuple< Row, Col, Row,  F8,  F8, F32, F16, Intrawave, CompV3>,
-    std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>,
-    std::tuple< Col, Row, Row,  F8,  F8, F32, F16, Intrawave, CompV3>,
-    std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>,
-    std::tuple< Col, Col, Row,  F8,  F8, F32, F16, Intrawave, CompV3>
+    std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>,
+    std::tuple< Row, Row, Row,  F8,  F8, F32, F16, Intrawave, CompV3>,
+    std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>,
+    std::tuple< Row, Col, Row,  F8,  F8, F32, F16, Intrawave, CompV3>,
+    std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>,
+    std::tuple< Col, Row, Row,  F8,  F8, F32, F16, Intrawave, CompV3>,
+    std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>,
+    std::tuple< Col, Col, Row,  F8,  F8, F32, F16, Intrawave, CompV3>,
+    std::tuple< Row, Row, Row,  I8,  I8, I32, I32, Intrawave, CompV3>,
+    std::tuple< Row, Col, Row,  I8,  I8, I32, I32, Intrawave, CompV3>,
+    std::tuple< Col, Row, Row,  I8,  I8, I32, I32, Intrawave, CompV3>,
+    std::tuple< Col, Col, Row,  I8,  I8, I32, I32, Intrawave, CompV3>
     >;
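A hypothetical use of the new dispatcher entries (parameter order per the comment rows in the dispatcher above; the trailing SwizzleA/UseStructuredSparsity defaults are assumed):

    // Sketch: resolve the int8 warp GEMM types registered above.
    using WarpGemmI8_32x32 = ck_tile::WarpGemmMfmaDispatcher<ck_tile::int8_t,
                                                             ck_tile::int8_t,
                                                             ck_tile::int32_t,
                                                             32, 32, 16,
                                                             /*TransposeC=*/false>;
    using WarpGemmI8_16x16 = ck_tile::WarpGemmMfmaDispatcher<ck_tile::int8_t,
                                                             ck_tile::int8_t,
                                                             ck_tile::int32_t,
                                                             16, 16, 32,
                                                             /*TransposeC=*/true>;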
using KernelTypesCompV4 = ::testing::Types<

diff --git a/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc b/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc
index 1f0683f8b8..c824d034a9 100644
--- a/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc
+++ b/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc
@@ -32,7 +32,8 @@ TYPED_TEST(TEST_SUITE_NAME, MidLargeM)
     constexpr int N = 1024;
     constexpr int K = 320;

     constexpr int VecLoadSize = (std::is_same_v<ADataType, ck_tile::fp8_t> ||
-                                 std::is_same_v<ADataType, ck_tile::bf8_t>)
+                                 std::is_same_v<ADataType, ck_tile::bf8_t> ||
+                                 std::is_same_v<ADataType, ck_tile::int8_t>)
                                     ? 16
                                     : 8;

@@ -41,7 +42,6 @@ TYPED_TEST(TEST_SUITE_NAME, MidLargeM)

     if constexpr(std::is_same_v<ALayout, Row>)
     {
-        // TODO: Can we anyhow deduce used vector load size?
         if(M % VecLoadSize == 0)
         {
             this->Run(M, N, K);

diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp
index 5f2a53645d..a6a4817143 100644
--- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp
+++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp
@@ -47,6 +47,8 @@ struct GemmPipelineTypeSelector
 {
     using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrMem;
     using pipeline      = ck_tile::GemmPipelineAgBgCrMem;
+
+    static constexpr auto GetName() { return "GemmPipelineAgBgCrMem"; }
 };

 template
@@ -54,6 +56,8 @@ struct GemmPipelineTypeSelector
 {
     using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3;
     using pipeline      = ck_tile::GemmPipelineAgBgCrCompV3;
+
+    static constexpr auto GetName() { return "GemmPipelineAgBgCrCompV3"; }
 };

 template
@@ -61,6 +65,8 @@ struct GemmPipelineTypeSelector
 {
     using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4;
     using pipeline      = ck_tile::GemmPipelineAgBgCrCompV4;
+
+    static constexpr auto GetName() { return "GemmPipelineAgBgCrCompV4"; }
 };

 template

diff --git a/tile_engine/ops/gemm/CMakeLists.txt b/tile_engine/ops/gemm/CMakeLists.txt
index cbba248211..c3c177487f 100644
--- a/tile_engine/ops/gemm/CMakeLists.txt
+++ b/tile_engine/ops/gemm/CMakeLists.txt
@@ -1,4 +1,3 @@
-
 # generate a list of kernels, but not actually emit files at config stage
 execute_process(
     COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py

diff --git a/tile_engine/ops/gemm/codegen_utils.py b/tile_engine/ops/gemm/codegen_utils.py
index f16a55ef87..ae496636c6 100644
--- a/tile_engine/ops/gemm/codegen_utils.py
+++ b/tile_engine/ops/gemm/codegen_utils.py
@@ -11,17 +11,21 @@ import subprocess
 import re
 from functools import lru_cache

-DATA_TYPE_MAP = {'fp32': 'float',
-                 'fp16': 'ck_tile::half_t',
-                 'bf16': 'ck_tile::bf16_t',
-                 'int8': 'ck_tile::int8_t',
-                 'fp8': 'ck_tile::fp8_t',
-                 'bf8': 'ck_tile::bf8_t',
-                 'int4': 'ck_tile::pk_int4_t'
-                 }
+DATA_TYPE_MAP = {
+    "fp32": "float",
+    "fp16": "ck_tile::half_t",
+    "bf16": "ck_tile::bf16_t",
+    "int8": "ck_tile::int8_t",
+    "fp8": "ck_tile::fp8_t",
+    "bf8": "ck_tile::bf8_t",
+    "int4": "ck_tile::pk_int4_t",
+    "int32": "ck_tile::int32_t",
+}

-LAYOUT_MAP = {'r': 'ck_tile::tensor_layout::gemm::RowMajor',
-              'c': 'ck_tile::tensor_layout::gemm::ColumnMajor'}
+LAYOUT_MAP = {
+    "r": "ck_tile::tensor_layout::gemm::RowMajor",
+    "c": "ck_tile::tensor_layout::gemm::ColumnMajor",
+}

 DEFAULT_EPILOGUE = """
 using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<

@@ -149,44 +153,109 @@ RUN_COMPV4 = """
 """

-PIPELINE_MAP = {'mem': ['ck_tile::BaseGemmPipelineAgBgCrMem', 'ck_tile::GemmPipelineAgBgCrMem'],
-                'compv3': ['ck_tile::BaseGemmPipelineAgBgCrCompV3', 'ck_tile::GemmPipelineAgBgCrCompV3'],
-                'compv4': ['ck_tile::BaseGemmPipelineAgBgCrCompV4', 'ck_tile::GemmPipelineAgBgCrCompV4']}
+PIPELINE_MAP = {
+    "mem": ["ck_tile::BaseGemmPipelineAgBgCrMem", "ck_tile::GemmPipelineAgBgCrMem"],
+    "compv3": [
+        "ck_tile::BaseGemmPipelineAgBgCrCompV3",
+        "ck_tile::GemmPipelineAgBgCrCompV3",
+    ],
+    "compv4": [
+        "ck_tile::BaseGemmPipelineAgBgCrCompV4",
+        "ck_tile::GemmPipelineAgBgCrCompV4",
+    ],
+}

-SCHEDULER_MAP = {'interwave': 'ck_tile::GemmPipelineScheduler::Interwave',
-                 'intrawave': 'ck_tile::GemmPipelineScheduler::Intrawave'}
+SCHEDULER_MAP = {
+    "interwave": "ck_tile::GemmPipelineScheduler::Interwave",
+    "intrawave": "ck_tile::GemmPipelineScheduler::Intrawave",
+}

-EPILOGUE_MAP = {'default': DEFAULT_EPILOGUE,
-                'cshuffle': CSHUFFLE_EPILOGUE}
+EPILOGUE_MAP = {"default": DEFAULT_EPILOGUE, "cshuffle": CSHUFFLE_EPILOGUE}

-HOT_LOOP_TRUE = {'mem': RUN_MEM,
-                 'compv3': RUN_COMPV3,
-                 'compv4': RUN_COMPV4}
+HOT_LOOP_TRUE = {"mem": RUN_MEM, "compv3": RUN_COMPV3, "compv4": RUN_COMPV4}

-def BOOL_MAP(b_): return {True: 'true', False: 'false'}[bool(b_)]
+def BOOL_MAP(b_):
+    return {True: "true", False: "false"}[bool(b_)]

 # To Do: add some more supported combinations
 warp_tile_supported_combinations = {
     "gfx90a": {
-        'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
-        'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
-        'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]],
-        'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]]
+        "fp16_fp16_fp16": [
+            [32, 32, 8],
+            [16, 16, 16],
+            [32, 32, 16],
+            [16, 16, 32],
+            [4, 64, 16],
+            [64, 4, 16],
+        ],
+        "bf16_bf16_bf16": [
+            [32, 32, 8],
+            [16, 16, 16],
+            [32, 32, 16],
+            [16, 16, 32],
+            [4, 64, 16],
+            [64, 4, 16],
+        ],
+        "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]],
+        "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]],
     },
     "gfx942": {
-        'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
-        'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
-        'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
-        'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]]
+        "fp16_fp16_fp16": [
+            [32, 32, 8],
+            [16, 16, 16],
+            [32, 32, 16],
+            [16, 16, 32],
+            [4, 64, 16],
+            [64, 4, 16],
+        ],
+        "bf16_bf16_bf16": [
+            [32, 32, 8],
+            [16, 16, 16],
+            [32, 32, 16],
+            [16, 16, 32],
+            [4, 64, 16],
+            [64, 4, 16],
+        ],
+        "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
+        "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]],
+        "int8_int8_int32": [[16, 16, 32], [32, 32, 16]],
     },
     "gfx950": {
-        'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
-        'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
-        'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]],
-        'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]]
-    }
+        "fp16_fp16_fp16": [
+            [32, 32, 8],
+            [16, 16, 16],
+            [32, 32, 16],
+            [16, 16, 32],
+            [4, 64, 16],
+            [64, 4, 16],
+        ],
+        "bf16_bf16_bf16": [
+            [32, 32, 8],
+            [16, 16, 16],
+            [32, 32, 16],
+            [16, 16, 32],
+            [4, 64, 16],
+            [64, 4, 16],
+        ],
+        "fp8_fp8_fp16": [
+            [32, 32, 16],
+            [32, 32, 32],
+            [16, 16, 32],
+            [16, 16, 64],
+            [16, 16, 128],
+            [32, 32, 64],
+        ],
+        "fp8_fp8_fp16": [
+            [32, 32, 16],
+            [32, 32, 32],
+            [16, 16, 64],
+            [16, 16, 32],
+            [16, 16, 128],
+            [32, 32, 64],
+        ],
+    },
 }

 # To Do: remove some unsupported combinations
"interwave"), - ("compv4", "default", "interwave") + ("compv4", "default", "interwave"), +} + + +ELEMENT_SIZE_MAP = { + "fp16": 2, + "bf16": 2, + "int8": 1, + "fp8": 1, + "bf8": 1, + "int4": 0.5, + "int32": 4, } def element_size(data_type: str) -> float: """Calculate the size (in bytes) of a single element for given data type.""" data_type = data_type.lower() - if data_type in {'fp16', 'bf16'}: - return 2 - elif data_type in {'int8', 'fp8', 'bf8'}: - return 1 - elif data_type == 'int4': - return 0.5 - else: + if data_type not in ELEMENT_SIZE_MAP: raise ValueError(f"Unsupported data type: {data_type}") + return ELEMENT_SIZE_MAP[data_type] -GPU_NAME_PATTERN = re.compile(r'Name:\s*(gfx\d+\w*)') +GPU_NAME_PATTERN = re.compile(r"Name:\s*(gfx\d+\w*)") @lru_cache(maxsize=1) @@ -219,10 +294,7 @@ def get_gpu_name_by_id(gpu_id: int = 0) -> str: """Retrieve GPU name (e.g. gfx90a) by device ID""" try: output = subprocess.check_output( - ["rocminfo"], - text=True, - stderr=subprocess.PIPE, - timeout=5 + ["rocminfo"], text=True, stderr=subprocess.PIPE, timeout=5 ) if matches := GPU_NAME_PATTERN.finditer(output): gpu_list = [m.group(1) for m in matches] diff --git a/tile_engine/ops/gemm/configs/default_config.json b/tile_engine/ops/gemm/configs/default_config.json index d20c5eef7d..9f71e430de 100644 --- a/tile_engine/ops/gemm/configs/default_config.json +++ b/tile_engine/ops/gemm/configs/default_config.json @@ -33,19 +33,19 @@ }, "tile_config": { "tile_m": { - "max": 512, + "max": 256, "min": 64, "step": 64, "exclude": [] }, "tile_n": { - "max": 512, + "max": 256, "min": 64, "step": 32, "exclude": [] }, "tile_k": { - "max": 512, + "max": 256, "min": 64, "step": 64, "exclude": [192] diff --git a/tile_engine/ops/gemm/configs/user_provided_config.json b/tile_engine/ops/gemm/configs/user_provided_config.json index 6a6e726e40..43c8784667 100644 --- a/tile_engine/ops/gemm/configs/user_provided_config.json +++ b/tile_engine/ops/gemm/configs/user_provided_config.json @@ -17,17 +17,17 @@ }, "datatype_a": { "values": [ - "fp16" + "int8" ] }, "datatype_b": { "values": [ - "fp16" + "int8" ] }, "datatype_c": { "values": [ - "fp16" + "int32" ] } }, @@ -44,7 +44,7 @@ }, "tile_k": { "values": [ - 32 + 128 ] }, "warp_m": { @@ -64,17 +64,17 @@ }, "warp_tile_m": { "values": [ - 32 + 16, 32 ] }, "warp_tile_n": { "values": [ - 32 + 16, 32 ] }, "warp_tile_k": { "values": [ - 16 + 16, 32 ] } }, diff --git a/tile_engine/ops/gemm/gemm_host_api.hpp b/tile_engine/ops/gemm/gemm_host_api.hpp index b3aab6ad92..2c4af8955f 100644 --- a/tile_engine/ops/gemm/gemm_host_api.hpp +++ b/tile_engine/ops/gemm/gemm_host_api.hpp @@ -50,6 +50,18 @@ struct DataTypeTraits static constexpr const char* name = "bf8"; }; +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int32"; +}; + template <> struct DataTypeTraits { diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index e7690ac481..f217522feb 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -29,10 +29,9 @@ from codegen_utils import ( warp_tile_supported_combinations, trait_unsupported_combinations, element_size, - get_gpu_name_by_id + get_gpu_name_by_id, ) import logging -import time logging.basicConfig(level=logging.INFO) @@ -40,16 +39,18 @@ logging.basicConfig(level=logging.INFO) class GemmCodeGenerator: """GEMM (General Matrix Multiplication) code 
generator.""" - def __init__(self, output_dir: str, - user_provided_config: Optional[GemmConfig] = None): + def __init__( + self, output_dir: str, user_provided_config: Optional[GemmConfig] = None + ): self.output_dir = Path(output_dir) self.output_dir.mkdir(parents=True, exist_ok=True) if user_provided_config is not None: self.config = user_provided_config else: - config_path = Path(__file__).resolve().parent / \ - "configs" / "default_config.json" + config_path = ( + Path(__file__).resolve().parent / "configs" / "default_config.json" + ) self.config = GemmConfig.from_json(config_path) self.valid_trait_names: List[str] = [] @@ -58,46 +59,82 @@ class GemmCodeGenerator: def list_all_trait_names(self): """List all possible kernel trait names into file.""" w_p = Path(self.output_dir) - file_path = w_p / 'gemm_instance_blobs.txt' + file_path = w_p / "gemm_instance_blobs.txt" self._generate_all_traits() self._get_valid_trait_tile_combinations() # Write all file paths to the header file - with file_path.open('w') as f: - f.write(str(w_p / "gemm_common.hpp") + "\n") - f.write(str(w_p / "gemm_instances.hpp") + "\n") - f.write(str(w_p / "gemm_dispatcher.hpp") + "\n") + files_listed = 0 + with file_path.open("w") as f: + # Core files + core_files = [ + "gemm_common.hpp", + "gemm_instances.hpp", + "gemm_dispatcher.hpp", + ] + for core_file in core_files: + f.write(str(w_p / core_file) + "\n") + files_listed += 1 + + # Trait header files for trait in self.valid_trait_names: - f.write(str(w_p / f"gemm_{trait}.hpp") + "\n") + trait_file = f"gemm_{trait}.hpp" + f.write(str(w_p / trait_file) + "\n") + files_listed += 1 + + # Instance source files for trait, tile_valid_params in self.valid_trait_tile_combinations.items(): for tile in tile_valid_params: - for tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n, warp_tile_k in tile: - sparse = self.config.problem.datatype_map['matrix_a'] == 'fp16' and \ - self.config.problem.datatype_map['matrix_b'] == 'fp16' and \ - self.config.problem.datatype_map['matrix_c'] == 'fp16' and \ - ((warp_tile_m == 32 and warp_tile_n == 32 and warp_tile_k == 16) or - (warp_tile_m == 16 and warp_tile_n == 16 and warp_tile_k == 32)) + for ( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + ) in tile: + instance_name = f"{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}" + sparse = ( + self.config.problem.datatype_map["matrix_a"] == "fp16" + and self.config.problem.datatype_map["matrix_b"] == "fp16" + and self.config.problem.datatype_map["matrix_c"] == "fp16" + and ( + ( + warp_tile_m == 32 + and warp_tile_n == 32 + and warp_tile_k == 16 + ) + or ( + warp_tile_m == 16 + and warp_tile_n == 16 + and warp_tile_k == 32 + ) + ) + ) if sparse: - f.write(str( - w_p / f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}_true.cpp") + "\n") - f.write(str( - w_p / f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}_false.cpp") + "\n") + sparse_file = f"gemm_{trait}_{instance_name}_true.cpp" + f.write(str(w_p / sparse_file) + "\n") + files_listed += 1 + + regular_file = f"gemm_{trait}_{instance_name}_false.cpp" + f.write(str(w_p / regular_file) + "\n") + files_listed += 1 + + print(f"File listing complete: {files_listed} files listed in {file_path}\n") def _generate_all_traits(self): """Generate all possible kernel traits names.""" - params = [ - 
"pipeline", - "epilogue", - "scheduler", - "pad_m", - "pad_n", - "pad_k"] + params = ["pipeline", "epilogue", "scheduler", "pad_m", "pad_n", "pad_k"] # Generate all unique_combinations - _unique = set(itertools.product(*[ - getattr(self.config.trait_config, param).values - for param in params - ])) + _unique = set( + itertools.product( + *[getattr(self.config.trait_config, param).values for param in params] + ) + ) for combo in _unique: pipeline, epilogue, scheduler, pad_m, pad_n, pad_k = combo @@ -110,9 +147,7 @@ class GemmCodeGenerator: ) self.valid_trait_names.append(trait_name) else: - logging.debug( - f"Invalid combination: {pipeline}-{epilogue}-{scheduler}" - ) + logging.debug(f"Invalid combination: {pipeline}-{epilogue}-{scheduler}") def generate_all_instance_files(self): """Generate all kernel instances files.""" @@ -123,6 +158,16 @@ class GemmCodeGenerator: def _generate_common_header_file(self): """Generate common header file with datatypes and layout.""" + # Determine appropriate accumulation type based on input types + a_type = self.config.problem.datatype_map["matrix_a"] + b_type = self.config.problem.datatype_map["matrix_b"] + c_type = self.config.problem.datatype_map["matrix_c"] + + if a_type in ["int8", "int4"] and b_type in ["int8", "int4"]: + acc_type = "ck_tile::int32_t" + else: + acc_type = "float" + content = f"""// SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. @@ -132,15 +177,15 @@ class GemmCodeGenerator: #include "ck_tile/ops/common.hpp" // Data types -using ADataType = {DATA_TYPE_MAP[self.config.problem.datatype_map['matrix_a']]}; -using BDataType = {DATA_TYPE_MAP[self.config.problem.datatype_map['matrix_b']]}; -using AccDataType = float; -using CDataType = {DATA_TYPE_MAP[self.config.problem.datatype_map['matrix_c']]}; +using ADataType = {DATA_TYPE_MAP[self.config.problem.datatype_map["matrix_a"]]}; +using BDataType = {DATA_TYPE_MAP[self.config.problem.datatype_map["matrix_b"]]}; +using AccDataType = {acc_type}; +using CDataType = {DATA_TYPE_MAP[self.config.problem.datatype_map["matrix_c"]]}; // Layout configurations -using ALayout = {LAYOUT_MAP[self.config.problem.layout_map['matrix_a']]}; -using BLayout = {LAYOUT_MAP[self.config.problem.layout_map['matrix_b']]}; -using CLayout = {LAYOUT_MAP[self.config.problem.layout_map['matrix_c']]}; +using ALayout = {LAYOUT_MAP[self.config.problem.layout_map["matrix_a"]]}; +using BLayout = {LAYOUT_MAP[self.config.problem.layout_map["matrix_b"]]}; +using CLayout = {LAYOUT_MAP[self.config.problem.layout_map["matrix_c"]]}; """ (self.output_dir / "gemm_common.hpp").write_text(content) @@ -174,13 +219,21 @@ namespace {trait} {{ """ # Add template struct with configuration content += self._generate_kernel_struct( - pipeline, epilogue, scheduler, pad_m, pad_n, pad_k) + pipeline, epilogue, scheduler, pad_m, pad_n, pad_k + ) content += f"\n}} // namespace {trait}\n" (self.output_dir / filename).write_text(content) - def _generate_kernel_struct(self, pipeline: str, epilogue: str, scheduler: str, - pad_m: str, pad_n: str, pad_k: str) -> str: + def _generate_kernel_struct( + self, + pipeline: str, + epilogue: str, + scheduler: str, + pad_m: str, + pad_n: str, + pad_k: str, + ) -> str: """Generate the code block of kernel struct""" return f""" @@ -193,7 +246,7 @@ struct GemmKernel {{ static constexpr bool kPadN = {pad_n}; static constexpr bool kPadK = {pad_k}; - static float launch(ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream) {{ + static float 
-    static float launch(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{
+    static float launch(ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream) {{
         static constexpr bool permuteA = false;
         static constexpr bool permuteB = false;
         static constexpr bool DoubleSmemBuffer = {"true" if pipeline == "compv4" else "false"};
@@ -307,6 +360,7 @@ struct GemmKernel {{
         if(args.k_batch > 1)
             hipGetErrorString(hipMemsetAsync(
                 args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream.stream_id_));
         }};
         ave_time = ck_tile::launch_kernel_preprocess(
             stream,
@@ -367,28 +421,36 @@ struct GemmKernel {{
 #pragma once
 """
         for trait in self.valid_trait_names:
-            content += f"#include \"gemm_{trait}.hpp\"\n"
+            content += f'#include "gemm_{trait}.hpp"\n'
         (self.output_dir / "gemm_instances.hpp").write_text(content)

     def is_tile_valid(self, tile: tuple, trait: str) -> bool:
         """Check if the tile configuration is valid for the given trait."""
-        tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n, warp_tile_k = tile
+        (
+            tile_m,
+            tile_n,
+            tile_k,
+            warp_m,
+            warp_n,
+            warp_k,
+            warp_tile_m,
+            warp_tile_n,
+            warp_tile_k,
+        ) = tile
         pipeline, *_ = trait.split("_")

         # Parameter validity check
         invalid_params = []
         if (warp_m, warp_n, warp_k) not in [(1, 4, 1), (2, 2, 1), (4, 1, 1)]:
             invalid_params.append(
-                f"warp_m({warp_m}) * warp_n({warp_n}) * warp_k({warp_k})")
+                f"warp_m({warp_m}) * warp_n({warp_n}) * warp_k({warp_k})"
+            )
         if (warp_m * warp_tile_m) == 0:
-            invalid_params.append(
-                f"warp_m({warp_m}) * warp_tile_m({warp_tile_m})")
+            invalid_params.append(f"warp_m({warp_m}) * warp_tile_m({warp_tile_m})")
         if (warp_n * warp_tile_n) == 0:
-            invalid_params.append(
-                f"warp_n({warp_n}) * warp_tile_n({warp_tile_n})")
+            invalid_params.append(f"warp_n({warp_n}) * warp_tile_n({warp_tile_n})")
         if (warp_k * warp_tile_k) == 0:
-            invalid_params.append(
-                f"warp_k({warp_k}) * warp_tile_k({warp_tile_k})")
+            invalid_params.append(f"warp_k({warp_k}) * warp_tile_k({warp_tile_k})")

         if invalid_params:
             logging.debug(
@@ -397,18 +459,20 @@
                 f"warp_tile=({warp_tile_m},{warp_tile_n},{warp_tile_k})"
             )
             return False
-
         # Dimension alignment check
         alignment_issues = []
         if tile_m % (warp_m * warp_tile_m) != 0:
             alignment_issues.append(
-                f"tile_m({tile_m}) % [{warp_m}x{warp_tile_m}] = {tile_m % (warp_m * warp_tile_m)}")
+                f"tile_m({tile_m}) % [{warp_m}x{warp_tile_m}] = {tile_m % (warp_m * warp_tile_m)}"
+            )
         if tile_n % (warp_n * warp_tile_n) != 0:
             alignment_issues.append(
-                f"tile_n({tile_n}) % [{warp_n}x{warp_tile_n}] = {tile_n % (warp_n * warp_tile_n)}")
+                f"tile_n({tile_n}) % [{warp_n}x{warp_tile_n}] = {tile_n % (warp_n * warp_tile_n)}"
+            )
         if tile_k % (warp_k * warp_tile_k) != 0:
             alignment_issues.append(
-                f"tile_k({tile_k}) % [{warp_k}x{warp_tile_k}] = {tile_k % (warp_k * warp_tile_k)}")
+                f"tile_k({tile_k}) % [{warp_k}x{warp_tile_k}] = {tile_k % (warp_k * warp_tile_k)}"
+            )

         if alignment_issues:
             logging.debug(
@@ -419,17 +483,20 @@
             return False

         # LDS capacity verification
-        matrix_a_size = (tile_m * tile_k) * \
-            element_size(self.config.problem.datatype_map['matrix_a'])
-        matrix_b_size = (tile_n * tile_k) * \
-            element_size(self.config.problem.datatype_map['matrix_b'])
+        matrix_a_size = (tile_m * tile_k) * element_size(
+            self.config.problem.datatype_map["matrix_a"]
+        )
+        matrix_b_size = (tile_n * tile_k) * element_size(
+            self.config.problem.datatype_map["matrix_b"]
+        )
         total_tile_in_lds = matrix_a_size + matrix_b_size
         max_tile_size = 2**15 if pipeline == "compv4" else 2**16
+
         if total_tile_in_lds > max_tile_size:
             logging.debug(
-                f"LDS capacity exceeded [{trait}]: Total required {total_tile_in_lds:,}B ({total_tile_in_lds/1024:.1f}KB) > "
-                f"maximum allowed {max_tile_size:,}B ({max_tile_size/1024}KB). Breakdown:\n"
+                f"LDS capacity exceeded [{trait}]: Total required {total_tile_in_lds:,}B ({total_tile_in_lds / 1024:.1f}KB) > "
+                f"maximum allowed {max_tile_size:,}B ({max_tile_size / 1024}KB). Breakdown:\n"
                 f"- Matrix A ({self.config.problem.datatype_map['matrix_a']}): {tile_m}x{tile_k} = {matrix_a_size:,}B\n"
                 f"- Matrix B ({self.config.problem.datatype_map['matrix_b']}): {tile_n}x{tile_k} = {matrix_b_size:,}B"
             )
             return False
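To make the LDS arithmetic above concrete for int8 (1 byte per element), here is the same budget check as a C++ sketch, using a 256x256 block tile with tile_k = 128 — the A and B tiles together land exactly on the non-compv4 cap:

    // Sketch mirroring is_tile_valid's LDS check; tile sizes are illustrative.
    constexpr int tile_m = 256, tile_n = 256, tile_k = 128;
    constexpr int elem_bytes = 1; // int8
    constexpr int lds_bytes  = (tile_m * tile_k + tile_n * tile_k) * elem_bytes;
    static_assert(lds_bytes == 64 * 1024, "A tile 32 KiB + B tile 32 KiB");
    static_assert(lds_bytes <= (1 << 16), "fits the non-compv4 2^16 B budget");
    static_assert(lds_bytes > (1 << 15), "would be rejected under compv4's 2^15 B budget");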
@@ -440,16 +507,19 @@
         current_combination = [warp_tile_m, warp_tile_n, warp_tile_k]

         gpu_name = get_gpu_name_by_id(0)
+
         gpu_warp_tile_key = warp_tile_supported_combinations.get(gpu_name, {})
         if not gpu_warp_tile_key:
             logging.debug(
-                f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check.")
+                f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check."
+            )
             return False

         allowed_combinations = gpu_warp_tile_key.get(warp_tile_key, [])
         if not allowed_combinations:
             logging.debug(
-                f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check.")
+                f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check."
+            )
             return False

         if current_combination not in allowed_combinations:
@@ -462,49 +532,68 @@
         return True

     def _get_valid_trait_tile_combinations(self):
-        def get_tile_value(tile_param): return tile_param.generate_candidates(
-        ) if isinstance(tile_param, RangeConfigParam) else tile_param.values
+        def get_tile_value(tile_param):
+            return (
+                tile_param.generate_candidates()
+                if isinstance(tile_param, RangeConfigParam)
+                else tile_param.values
+            )

-        tile_group = list(itertools.product(
-            get_tile_value(self.config.tile_config.tile_m),
-            get_tile_value(self.config.tile_config.tile_n),
-            get_tile_value(self.config.tile_config.tile_k)
-        ))
+        tile_group = list(
+            itertools.product(
+                get_tile_value(self.config.tile_config.tile_m),
+                get_tile_value(self.config.tile_config.tile_n),
+                get_tile_value(self.config.tile_config.tile_k),
+            )
+        )

-        warp_group = list(itertools.product(
-            get_tile_value(self.config.tile_config.warp_m),
-            get_tile_value(self.config.tile_config.warp_n),
-            get_tile_value(self.config.tile_config.warp_k)
-        ))
+        warp_group = list(
+            itertools.product(
+                get_tile_value(self.config.tile_config.warp_m),
+                get_tile_value(self.config.tile_config.warp_n),
+                get_tile_value(self.config.tile_config.warp_k),
+            )
+        )

-        warp_tile_group = list(itertools.product(
-            get_tile_value(self.config.tile_config.warp_tile_m),
-            get_tile_value(self.config.tile_config.warp_tile_n),
-            get_tile_value(self.config.tile_config.warp_tile_k)
-        ))
+        warp_tile_group = list(
+            itertools.product(
+                get_tile_value(self.config.tile_config.warp_tile_m),
+                get_tile_value(self.config.tile_config.warp_tile_n),
+                get_tile_value(self.config.tile_config.warp_tile_k),
+            )
+        )

         tile_params = {
-            t + w + wt
-            for t in tile_group
-            for w in warp_group
-            for wt in warp_tile_group
+            t + w + wt for t in tile_group for w in warp_group for wt in warp_tile_group
         }

         for trait in self.valid_trait_names:
-            tile_valid_params = list(
-                filter(lambda t: self.is_tile_valid(t, trait), tile_params))
+            tile_valid_params = [
+                tile for tile in tile_params if self.is_tile_valid(tile, trait)
+            ]

-            # if len(tile_valid_params) == 0:
-            #     raise RuntimeError(f"No valid kernel instance selected for trait: {trait}")
             if trait not in self.valid_trait_tile_combinations:
                 self.valid_trait_tile_combinations[trait] = []
             self.valid_trait_tile_combinations[trait].append(tile_valid_params)

     def _generate_instantiation_source_files(self):
-        """Generate kernel instance instantiation source files """
+        """Generate kernel instance instantiation source files"""
+
         for trait, tile_valid_params in self.valid_trait_tile_combinations.items():
             for tile in tile_valid_params:
-                for tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n, warp_tile_k in tile:
+                for (
+                    tile_m,
+                    tile_n,
+                    tile_k,
+                    warp_m,
+                    warp_n,
+                    warp_k,
+                    warp_tile_m,
+                    warp_tile_n,
+                    warp_tile_k,
+                ) in tile:
+                    instance_name = f"{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
+
                     content = f"""
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

@@ -514,23 +603,41 @@
 #include "gemm_{trait}.hpp"
 """
-                    sparse = self.config.problem.datatype_map['matrix_a'] == 'fp16' and \
-                        self.config.problem.datatype_map['matrix_b'] == 'fp16' and \
-                        self.config.problem.datatype_map['matrix_c'] == 'fp16' and \
-                        ((warp_tile_m == 32 and warp_tile_n == 32 and warp_tile_k == 16) or
-                         (warp_tile_m == 16 and warp_tile_n == 16 and warp_tile_k == 32))
+                    sparse = (
+                        self.config.problem.datatype_map["matrix_a"] == "fp16"
+                        and self.config.problem.datatype_map["matrix_b"] == "fp16"
+                        and self.config.problem.datatype_map["matrix_c"] == "fp16"
+                        and (
+                            (
+                                warp_tile_m == 32
+                                and warp_tile_n == 32
+                                and warp_tile_k == 16
+                            )
+                            or (
+                                warp_tile_m == 16
+                                and warp_tile_n == 16
+                                and warp_tile_k == 32
+                            )
+                        )
+                    )

                     if sparse:
-                        sparse_content = content + f"""
+                        sparse_filename = f"gemm_{trait}_{instance_name}_true.cpp"
+                        sparse_content = (
+                            content
+                            + f"""
 template struct {trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, true>;
 """
-                        (self.output_dir /
-                         f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}_true.cpp").write_text(sparse_content)
+                        )
+                        (self.output_dir / sparse_filename).write_text(sparse_content)

-                    no_sparse_content = content + f"""
+                    no_sparse_filename = f"gemm_{trait}_{instance_name}_false.cpp"
+                    no_sparse_content = (
+                        content
+                        + f"""
 template struct {trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, false>;
 """
-                    (self.output_dir /
-                     f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}_false.cpp").write_text(no_sparse_content)
+                    )
+                    (self.output_dir / no_sparse_filename).write_text(no_sparse_content)

     def _generate_dispatcher_file(self):
         """Generate the code block of dispatch mechanism."""
@@ -576,7 +683,7 @@ struct GemmDispatcher {
     }

     static void init(bool structured_sparsity) {
-        ck_tile::ignore = structured_sparsity;
+        (void)structured_sparsity; // Suppress unused parameter warning
         auto& kernel_map = get_kernel_map();
         if(!kernel_map.empty()) return;
\n"""

         for trait, tile_valid_params in self.valid_trait_tile_combinations.items():
             content += f"""
        kernel_map["{trait}"] = {{"""
             for _, tile in enumerate(tile_valid_params):
                 for j in range(len(tile)):
-                        tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n, warp_tile_k = tile[
-                            j]
-                        content += f"""[=](ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ """
+                        (
+                            tile_m,
+                            tile_n,
+                            tile_k,
+                            warp_m,
+                            warp_n,
+                            warp_k,
+                            warp_tile_m,
+                            warp_tile_n,
+                            warp_tile_k,
+                        ) = tile[j]
+                        content += f"""[=](ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream) {{ """
                         content += f"""
                 if(structured_sparsity){{ // SMFMA"""
-                        sparse = self.config.problem.datatype_map['matrix_a'] == 'fp16' and \
-                            self.config.problem.datatype_map['matrix_b'] == 'fp16' and \
-                            self.config.problem.datatype_map['matrix_c'] == 'fp16' and \
-                            ((warp_tile_m == 32 and warp_tile_n == 32 and warp_tile_k == 16) or
-                             (warp_tile_m == 16 and warp_tile_n == 16 and warp_tile_k == 32))
+                        sparse = (
+                            self.config.problem.datatype_map["matrix_a"] == "fp16"
+                            and self.config.problem.datatype_map["matrix_b"] == "fp16"
+                            and self.config.problem.datatype_map["matrix_c"] == "fp16"
+                            and (
+                                (
+                                    warp_tile_m == 32
+                                    and warp_tile_n == 32
+                                    and warp_tile_k == 16
+                                )
+                                or (
+                                    warp_tile_m == 16
+                                    and warp_tile_n == 16
+                                    and warp_tile_k == 32
+                                )
+                            )
+                        )
                         content += f"""
                     return run_kernel<{trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, {BOOL_MAP(sparse)}>>(args, stream);"""
                         content += f"""
@@ -604,7 +732,7 @@ struct GemmDispatcher {
                         content += f"""
        }}
"""
-                        if j == len(tile)-1:
+                        if j == len(tile) - 1:
                             content += f"""
 }}
 """
                         else:
@@ -651,22 +779,26 @@ private:

         (self.output_dir / "gemm_dispatcher.hpp").write_text(content)

-def do_list_blobs(args: argparse.Namespace,
-                  user_provide_config: Optional[GemmConfig] = None):
+def do_list_blobs(
+    args: argparse.Namespace, user_provide_config: Optional[GemmConfig] = None
+):
     generator = GemmCodeGenerator(args.working_path, user_provide_config)
     generator.list_all_trait_names()

-def do_gen_blobs(args: argparse.Namespace,
-                 user_provide_config: Optional[GemmConfig] = None):
+def do_gen_blobs(
+    args: argparse.Namespace, user_provide_config: Optional[GemmConfig] = None
+):
     generator = GemmCodeGenerator(args.working_path, user_provide_config)
     generator.generate_all_instance_files()

 def main(args):
-
-    gemm_config = GemmConfig.from_json(
-        args.config_json) if args.config_json is not None else args.config_json
+    gemm_config = (
+        GemmConfig.from_json(args.config_json)
+        if args.config_json is not None
+        else args.config_json
+    )

     if args.list_blobs:
         do_list_blobs(args, gemm_config)
@@ -674,7 +806,8 @@ def main(args):
         do_gen_blobs(args, gemm_config)
     else:
         logging.warning(
-            "No mode specified (use --list_blobs or --gen_blobs). Generating by default...")
+            "No mode specified (use --list_blobs or --gen_blobs). Generating by default..."
+        )
         do_gen_blobs(args, gemm_config)

@@ -684,16 +817,29 @@ if __name__ == "__main__":
         description="gen API for CK gemm kernel",
     )
     parser.add_argument(
-        "-w", "--working_path", default="./", required=False, help="The path where all the blobs are going to be generated"
+        "-w",
+        "--working_path",
+        default="./",
+        required=False,
+        help="The path where all the blobs are going to be generated",
     )
     parser.add_argument(
-        "-j", "--config_json", required=False, help="Path to the json which contains the configurations that user provide"
+        "-j",
+        "--config_json",
+        required=False,
+        help="Path to the json which contains the configurations that user provide",
     )
     parser.add_argument(
-        "-l", "--list_blobs", action='store_true', help="List all kernel instances to file"
+        "-l",
+        "--list_blobs",
+        action="store_true",
+        help="List all kernel instances to file",
    )
     parser.add_argument(
-        "-g", "--gen_blobs", action='store_true', help="Generate all kernel instances into different files"
+        "-g",
+        "--gen_blobs",
+        action="store_true",
+        help="Generate all kernel instances into different files",
     )

     args = parser.parse_args()

diff --git a/tile_engine/ops/gemm/gemm_profiler.hpp b/tile_engine/ops/gemm/gemm_profiler.hpp
index 0fd87ec07d..272799e4d6 100644
--- a/tile_engine/ops/gemm/gemm_profiler.hpp
+++ b/tile_engine/ops/gemm/gemm_profiler.hpp
@@ -23,6 +23,7 @@ class GemmProfiler
     void benchmark(GemmProblem& gemm_problem,
                    std::vector<std::function<float(
-                       ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>& callables)
+                       ck_tile::GemmHostArgs<>&, const ck_tile::stream_config&)>>& callables)
     {
         const ALayout layout_a = ALayout{};
         const BLayout layout_b = BLayout{};

@@ -89,17 +90,20 @@ class GemmProfiler
         c_m_n_dev_buf.SetZero();
         c_m_n_dev_result.SetZero();

-        ck_tile::GemmHostArgs<> gemm_args;
-        gemm_args.a_ptr    = a_m_k_dev_buf.GetDeviceBuffer();
-        gemm_args.b_ptr    = b_k_n_dev_buf.GetDeviceBuffer();
-        gemm_args.c_ptr    = c_m_n_dev_buf.GetDeviceBuffer();
-        gemm_args.k_batch  = gemm_problem.split_k_;
-        gemm_args.M        = gemm_problem.m_;
-        gemm_args.N        = gemm_problem.n_;
-        gemm_args.K        = gemm_problem.k_;
-        gemm_args.stride_A = gemm_problem.stride_a_;
-        gemm_args.stride_B = gemm_problem.stride_b_;
-        gemm_args.stride_C = gemm_problem.stride_c_;
+        ck_tile::GemmHostArgs<> gemm_args = {
+            a_m_k_dev_buf.GetDeviceBuffer(),
+            b_k_n_dev_buf.GetDeviceBuffer(),
+            {}, // ds_ptr
+            c_m_n_dev_buf.GetDeviceBuffer(),
+            gemm_problem.split_k_,
+            gemm_problem.m_,
+            gemm_problem.n_,
+            gemm_problem.k_,
+            gemm_problem.stride_a_,
+            gemm_problem.stride_b_,
+            {}, // stride_Ds
+            gemm_problem.stride_c_,
+        };

         ck_tile::HostTensor<CDataType> c_m_n_host_result(ck_tile::host_tensor_descriptor(
             gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c)));

diff --git a/tile_engine/ops/gemm/json_config.py b/tile_engine/ops/gemm/json_config.py
index 597caba76f..aaf732c6a8 100644
--- a/tile_engine/ops/gemm/json_config.py
+++ b/tile_engine/ops/gemm/json_config.py
@@ -16,12 +16,14 @@ import json
 @dataclass
 class EnumConfigParam:
     """Represents an enumeration-type configuration parameter"""
+
     values: List[Union[int, str, bool]]

 @dataclass
 class RangeConfigParam:
     """Represents a numeric range-type configuration parameter"""
+
     min: int
     max: int
     step: int

@@ -31,17 +33,13 @@ class RangeConfigParam:
         """Generates valid candidates after applying range constraints"""

         if self.min > self.max:
-            raise ValueError(
-                f"Invalid range: min({self.min}) > max({self.max})"
-            )
+            raise ValueError(f"Invalid range: min({self.min}) > max({self.max})")

         if self.step <= 0:
-            raise ValueError(
-                f"Step must be positive, got {self.step}"
-            )
+            raise ValueError(f"Step must be positive, got {self.step}")

         candidates = list(range(self.min, self.max + 1, self.step))

-        if hasattr(self, 'exclude') and self.exclude:
+        if hasattr(self, "exclude") and self.exclude:
             if not isinstance(self.exclude, list):
                 raise TypeError("exclude must be list type")
             exclude_set = set(self.exclude)

@@ -59,6 +57,7 @@ class RangeConfigParam:
 @dataclass
 class ProblemConfig:
     """configuration class for problem parameter."""
+
     datatypes: Tuple[EnumConfigParam, ...]
     layouts: Tuple[EnumConfigParam, ...]

@@ -66,24 +65,25 @@ class ProblemConfig:
     def datatype_map(self) -> Dict[str, str]:
         """Get datatype as a key-value map."""
         return {
-            'matrix_a': self.datatypes[0].values[0],
-            'matrix_b': self.datatypes[1].values[0],
-            'matrix_c': self.datatypes[2].values[0]
+            "matrix_a": self.datatypes[0].values[0],
+            "matrix_b": self.datatypes[1].values[0],
+            "matrix_c": self.datatypes[2].values[0],
         }

     @property
     def layout_map(self) -> Dict[str, str]:
         """Get layout as a key-value map."""
         return {
-            'matrix_a': self.layouts[0].values[0],
-            'matrix_b': self.layouts[1].values[0],
-            'matrix_c': self.layouts[2].values[0]
+            "matrix_a": self.layouts[0].values[0],
+            "matrix_b": self.layouts[1].values[0],
+            "matrix_c": self.layouts[2].values[0],
         }

 @dataclass
 class TileConfig:
     """Configuration class for tile parameter."""
+
     tile_m: Union[EnumConfigParam, RangeConfigParam]
     tile_n: Union[EnumConfigParam, RangeConfigParam]
     tile_k: Union[EnumConfigParam, RangeConfigParam]

@@ -100,6 +100,7 @@ class TileConfig:
 @dataclass
 class TraitConfig:
     """Configuration class for kernel traits."""
+
     pipeline: EnumConfigParam
     scheduler: EnumConfigParam
     epilogue: EnumConfigParam

@@ -110,7 +111,8 @@ class TraitConfig:
 @dataclass
 class GemmConfig:
-    """Main configuration class for GEMM operations """
+    """Main configuration class for GEMM operations"""
+
     problem: ProblemConfig
     tile_config: TileConfig
     trait_config: TraitConfig

@@ -124,76 +126,83 @@ class GemmConfig:
         if not config_path.exists():
             raise FileNotFoundError(f"Config file {filepath} not found")

-        with config_path.open('r') as f:
+        with config_path.open("r") as f:
             config_dict = json.load(f)

         # Parse problem config
         problem = ProblemConfig(
             datatypes=(
                 EnumConfigParam(
-                    values=config_dict['problem']['datatype_a']['values']),
+                    values=config_dict["problem"]["datatype_a"]["values"]
+                ),
                 EnumConfigParam(
-                    values=config_dict['problem']['datatype_b']['values']),
+                    values=config_dict["problem"]["datatype_b"]["values"]
+                ),
                 EnumConfigParam(
-                    values=config_dict['problem']['datatype_c']['values'])
+                    values=config_dict["problem"]["datatype_c"]["values"]
+                ),
             ),
             layouts=(
                 EnumConfigParam(
-                    values=config_dict['problem']['layout_a']['values']),
+                    values=config_dict["problem"]["layout_a"]["values"]
+                ),
                 EnumConfigParam(
-                    values=config_dict['problem']['layout_b']['values']),
+                    values=config_dict["problem"]["layout_b"]["values"]
+                ),
                 EnumConfigParam(
-                    values=config_dict['problem']['layout_c']['values'])
-            )
+                    values=config_dict["problem"]["layout_c"]["values"]
+                ),
+            ),
         )

         # Parse tile config
         def create_param(param_dict):
-            if 'values' in param_dict:
-                return EnumConfigParam(values=param_dict['values'])
+            if "values" in param_dict:
+                return EnumConfigParam(values=param_dict["values"])
             else:
                 return RangeConfigParam(
-                    min=param_dict['min'],
-                    max=param_dict['max'],
-                    step=param_dict['step'],
-                    exclude=param_dict.get('exclude', [])
+                    min=param_dict["min"],
+                    max=param_dict["max"],
+                    step=param_dict["step"],
+                    exclude=param_dict.get("exclude", []),
                 )

         tile_config = TileConfig(
-            tile_m=create_param(config_dict['tile_config']['tile_m']),
-            tile_n=create_param(config_dict['tile_config']['tile_n']),
-            tile_k=create_param(config_dict['tile_config']['tile_k']),
-            warp_m=create_param(config_dict['tile_config']['warp_m']),
-            warp_n=create_param(config_dict['tile_config']['warp_n']),
-            warp_k=create_param(config_dict['tile_config']['warp_k']),
-            warp_tile_m=create_param(
-                config_dict['tile_config']['warp_tile_m']),
-            warp_tile_n=create_param(
-                config_dict['tile_config']['warp_tile_n']),
-            warp_tile_k=create_param(
-                config_dict['tile_config']['warp_tile_k'])
+            tile_m=create_param(config_dict["tile_config"]["tile_m"]),
+            tile_n=create_param(config_dict["tile_config"]["tile_n"]),
+            tile_k=create_param(config_dict["tile_config"]["tile_k"]),
+            warp_m=create_param(config_dict["tile_config"]["warp_m"]),
+            warp_n=create_param(config_dict["tile_config"]["warp_n"]),
+            warp_k=create_param(config_dict["tile_config"]["warp_k"]),
+            warp_tile_m=create_param(config_dict["tile_config"]["warp_tile_m"]),
+            warp_tile_n=create_param(config_dict["tile_config"]["warp_tile_n"]),
+            warp_tile_k=create_param(config_dict["tile_config"]["warp_tile_k"]),
         )

         # Parse trait config
         trait_config = TraitConfig(
             pipeline=EnumConfigParam(
-                values=config_dict['trait_config']['pipeline']['values']),
+                values=config_dict["trait_config"]["pipeline"]["values"]
+            ),
             scheduler=EnumConfigParam(
-                values=config_dict['trait_config']['scheduler']['values']),
+                values=config_dict["trait_config"]["scheduler"]["values"]
+            ),
             epilogue=EnumConfigParam(
-                values=config_dict['trait_config']['epilogue']['values']),
+                values=config_dict["trait_config"]["epilogue"]["values"]
+            ),
             pad_m=EnumConfigParam(
-                values=config_dict['trait_config']['pad_m']['values']),
+                values=config_dict["trait_config"]["pad_m"]["values"]
+            ),
             pad_n=EnumConfigParam(
-                values=config_dict['trait_config']['pad_n']['values']),
+                values=config_dict["trait_config"]["pad_n"]["values"]
+            ),
             pad_k=EnumConfigParam(
-                values=config_dict['trait_config']['pad_k']['values'])
+                values=config_dict["trait_config"]["pad_k"]["values"]
+            ),
         )

         return cls(
-            problem=problem,
-            tile_config=tile_config,
-            trait_config=trait_config
+            problem=problem, tile_config=tile_config, trait_config=trait_config
         )

     except json.JSONDecodeError as e: