diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 0081edcb2e..225997439e 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -49,8 +49,9 @@ struct CShuffleEpilogue using BDataType = remove_cvref_t; using AccDataType = remove_cvref_t; using ODataType = remove_cvref_t; + // Used for weight-only quantization kernel, B would be dequantized to the same data type as A using BTypeToUse = - std::conditional_t, ODataType, BDataType>; + std::conditional_t, ADataType, BDataType>; using CLayout = remove_cvref_t; static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kMPerBlock = Problem::kMPerBlock; diff --git a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp index 6e290fe6d7..1d6a99eb4b 100644 --- a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp @@ -25,7 +25,9 @@ struct Default2DEpilogueProblem static constexpr bool UseRawStore = UseRawStore_; }; -template { + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; using CLayout = remove_cvref_t; static constexpr index_t kMPerXdl = kMPerXdl_; static constexpr index_t kNPerXdl = kNPerXdl_; @@ -96,17 +100,22 @@ struct Default2DEpilogue template struct DefaultGemm2DEpilogue : public Default2DEpilogue { - using Problem = remove_cvref_t; - using AccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; + using Problem = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + // Used for weight-only quantization kernel, B would be dequantized to the same data type as A + using BTypeToUse = + std::conditional_t, ADataType, BDataType>; using CLayout = remove_cvref_t; static constexpr index_t kMPerXdl = Problem::kMPerXdl; static constexpr index_t kNPerXdl = Problem::kNPerXdl; static constexpr index_t kKPerXdl = Problem::kKPerXdl; static constexpr index_t isCTransposed = Problem::isCTransposed; - using WG = WarpGemmMfmaDispatcher(BLdsTileDistr)); ALdsTile a_warp_tile_; - ALdsTile b_warp_tile_; + BLdsTile b_warp_tile_; template CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 4732027e57..22962b9404 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -216,6 +216,18 @@ using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl>>; +using WarpGemmMfma_f32_16x16x128_fp8_fp8 = WarpGemmImpl>>; + +using WarpGemmMfma_f32_16x16x128_fp8_bf8 = WarpGemmImpl>>; + +using WarpGemmMfma_f32_16x16x128_bf8_fp8 = WarpGemmImpl>>; + +using WarpGemmMfma_f32_16x16x128_bf8_bf8 = WarpGemmImpl>>; + using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed = WarpGemmImpl>>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index 08f813a1e3..cd32f35180 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -1342,6 +1342,104 @@ template using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8 = WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; +template +struct WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = AType_; + using BDataType = BType_; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 16; + static constexpr index_t kN = 16; + static constexpr index_t kK = 128; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 16; + static constexpr index_t kBNLane = 16; + static constexpr index_t kABKLane = 4; + static constexpr index_t kABKPerLane = 32; + + static constexpr index_t kCMLane = 4; + static constexpr index_t kCNLane = 16; + static constexpr index_t kCM0PerLane = 1; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + //__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a, + // opsel, scale_b) +#if defined(__gfx950__) + if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + a_vec, b_vec, c_vec, 0, 0, 0, 0, 0, 0); + else if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + a_vec, b_vec, c_vec, 0, 1, 0, 0, 0, 0); + else if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + a_vec, b_vec, c_vec, 1, 0, 0, 0, 0, 0); + else if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + a_vec, b_vec, c_vec, 1, 1, 0, 0, 0, 0); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; +#endif + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { +#if defined(__gfx950__) + if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + a_vec, b_vec, CVecType{0.f}, 0, 0, 0, 0, 0, 0)); + else if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + a_vec, b_vec, CVecType{0.f}, 0, 1, 0, 0, 0, 0)); + else if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + a_vec, b_vec, CVecType{0.f}, 1, 0, 0, 0, 0, 0)); + else if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + a_vec, b_vec, CVecType{0.f}, 1, 1, 0, 0, 0, 0)); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + return CVecType{0.f}; +#endif + } +}; + +template +using WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_fp8 = + WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base; + +template +using WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_bf8 = + WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base; + +template +using WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_fp8 = + WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base; + +template +using WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_bf8 = + WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base; + // int8 template struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8 diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index f437ee10c5..0e3342c479 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -69,6 +69,11 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8; }; + // clang-format on } // namespace impl diff --git a/tile_engine/ops/gemm/CMakeLists.txt b/tile_engine/ops/gemm/CMakeLists.txt index d28017ca0c..bc613a931e 100644 --- a/tile_engine/ops/gemm/CMakeLists.txt +++ b/tile_engine/ops/gemm/CMakeLists.txt @@ -8,6 +8,10 @@ execute_process( --list_blobs RESULT_VARIABLE ret ) +set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS + ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py + ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json +) if(ret AND NOT ret EQUAL 0) message( FATAL_ERROR "Fail to generate kernels via Python. ${ret}") @@ -21,7 +25,9 @@ add_custom_command( --working_path ${CMAKE_CURRENT_BINARY_DIR} --json ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json --gen_blobs - DEPENDS ${GEMM_CODEGEN_BLOBS} + DEPENDS ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py + ${CMAKE_CURRENT_BINARY_DIR}/gemm_instance_blobs.txt + ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json ) set(EXECUTABLE_GEMM_INSTANCE "tile_engine_gemm") diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index b6c7685fb2..b441bdd2d6 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -27,7 +27,9 @@ LAYOUT_MAP = {'r' : 'ck_tile::tensor_layout::gemm::RowMajor', DEFAULT_EPILOGUE = """ using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue< - ck_tile::DefaultGemm2DEpilogueProblem