From bb1f6e48eb212b6380d53055e7e57ba340aa85f3 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Tue, 19 Mar 2024 23:29:51 +0000 Subject: [PATCH] fix fp8 duplicated move/shift/and/or problem --- include/ck_tile/core.hpp | 1 + include/ck_tile/host.hpp | 1 + include/ck_tile/ops/common.hpp | 1 + include/ck_tile/ops/epilogue.hpp | 1 + include/ck_tile/ops/fmha.hpp | 1 + .../gemm/warp/warp_gemm_attribute_mfma.hpp | 82 +++++++--- .../warp/warp_gemm_attribute_mfma_impl.hpp | 143 +++++------------- .../ck_tile/ops/gemm/warp/warp_gemm_impl.hpp | 12 +- 8 files changed, 108 insertions(+), 134 deletions(-) diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 6b1c11fa27..9ac55c1197 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -54,3 +54,4 @@ #include "ck_tile/core/utility/to_sequence.hpp" #include "ck_tile/core/utility/transpose_vectors.hpp" #include "ck_tile/core/utility/type_traits.hpp" + diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index 0c4a778226..1bbb4b9539 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -20,3 +20,4 @@ #include "ck_tile/host/reference/reference_reduce.hpp" #include "ck_tile/host/reference/reference_softmax.hpp" #include "ck_tile/host/stream_config.hpp" + diff --git a/include/ck_tile/ops/common.hpp b/include/ck_tile/ops/common.hpp index 4363ea1f55..9fc1c0d0c1 100644 --- a/include/ck_tile/ops/common.hpp +++ b/include/ck_tile/ops/common.hpp @@ -4,3 +4,4 @@ #pragma once #include "ck_tile/ops/common/tensor_layout.hpp" + diff --git a/include/ck_tile/ops/epilogue.hpp b/include/ck_tile/ops/epilogue.hpp index 388f52c898..ab399dbf7a 100644 --- a/include/ck_tile/ops/epilogue.hpp +++ b/include/ck_tile/ops/epilogue.hpp @@ -5,3 +5,4 @@ #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" + diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 1e9acc6d7b..f886d470d5 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -18,3 +18,4 @@ #include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" + diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index 420870e61d..2c1335c61b 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -75,8 +75,10 @@ struct WarpGemmAtrributeMfmaIterateK using BDataType = typename Impl::BDataType; using CDataType = typename Impl::CDataType; - using AVecType = array; - using BVecType = array; + using AVecType = + ext_vector_t::vector_size * kKIter>; + using BVecType = + ext_vector_t::vector_size * kKIter>; using CVecType = typename Impl::CVecType; static constexpr index_t kM = Impl::kM; @@ -112,10 +114,15 @@ struct WarpGemmAtrributeMfmaIterateK CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const { + using buf_a = thread_buffer; + using buf_b = thread_buffer; + static_for<0, kKIter, 1>{}([&](auto iKIter) { Impl{}(c_vec, - a_vec.template get_as()[iKIter], - b_vec.template get_as()[iKIter]); + reinterpret_cast(a_vec) + .template get_as()[iKIter], + reinterpret_cast(b_vec) + .template get_as()[iKIter]); }); } @@ -123,16 +130,21 @@ struct WarpGemmAtrributeMfmaIterateK CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { constexpr auto I0 = number<0>{}; + using buf_a = thread_buffer; + using buf_b = thread_buffer; // c = a * b - auto c_vec = Impl{}(a_vec.template get_as()[I0], - b_vec.template get_as()[I0]); + auto c_vec = Impl{}( + reinterpret_cast(a_vec).template get_as()[I0], + reinterpret_cast(b_vec).template get_as()[I0]); // c += a * b static_for<1, kKIter, 1>{}([&](auto iKIter) { Impl{}(c_vec, - a_vec.template get_as()[iKIter], - b_vec.template get_as()[iKIter]); + reinterpret_cast(a_vec) + .template get_as()[iKIter], + reinterpret_cast(b_vec) + .template get_as()[iKIter]); }); return c_vec; @@ -269,8 +281,10 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution using BDataType = typename Impl::ADataType; using CDataType = typename Impl::CDataType; - using AVecType = array; - using BVecType = array; + using AVecType = + ext_vector_t::vector_size * kKIter>; + using BVecType = + ext_vector_t::vector_size * kKIter>; using CVecType = typename Impl::CVecType; static constexpr index_t kM = Impl::kN; @@ -306,11 +320,15 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const { + using buf_a = thread_buffer; + using buf_b = thread_buffer; // swap A and B, value and type static_for<0, kKIter, 1>{}([&](auto iKIter) { Impl{}(c_vec, - b_vec.template get_as()[iKIter], - a_vec.template get_as()[iKIter]); + reinterpret_cast(b_vec) + .template get_as()[iKIter], + reinterpret_cast(a_vec) + .template get_as()[iKIter]); }); } @@ -318,15 +336,20 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { constexpr auto I0 = number<0>{}; + using buf_a = thread_buffer; + using buf_b = thread_buffer; // swap A and B, value and type - auto c_vec = Impl{}(b_vec.template get_as()[I0], - a_vec.template get_as()[I0]); + auto c_vec = Impl{}( + reinterpret_cast(b_vec).template get_as()[I0], + reinterpret_cast(a_vec).template get_as()[I0]); static_for<1, kKIter, 1>{}([&](auto iKIter) { Impl{}(c_vec, - b_vec.template get_as()[iKIter], - a_vec.template get_as()[iKIter]); + reinterpret_cast(b_vec) + .template get_as()[iKIter], + reinterpret_cast(a_vec) + .template get_as()[iKIter]); }); return c_vec; @@ -343,8 +366,10 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB using BDataType = typename Impl::ADataType; using CDataType = typename Impl::CDataType; - using AVecType = array; - using BVecType = array; + using AVecType = + ext_vector_t::vector_size * kKIter>; + using BVecType = + ext_vector_t::vector_size * kKIter>; using CVecType = typename Impl::CVecType; static constexpr index_t kM = Impl::kN; @@ -406,27 +431,36 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const { + using buf_a = thread_buffer; + using buf_b = thread_buffer; // swap A and B, value and type static_for<0, kKIter, 1>{}([&](auto iKIter) { Impl{}(c_vec, - b_vec.template get_as()[iKIter], - a_vec.template get_as()[iKIter]); + reinterpret_cast(b_vec) + .template get_as()[iKIter], + reinterpret_cast(a_vec) + .template get_as()[iKIter]); }); } // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { + using buf_a = thread_buffer; + using buf_b = thread_buffer; constexpr auto I0 = number<0>{}; // swap A and B, value and type - auto c_vec = Impl{}(b_vec.template get_as()[I0], - a_vec.template get_as()[I0]); + auto c_vec = Impl{}( + reinterpret_cast(b_vec).template get_as()[I0], + reinterpret_cast(a_vec).template get_as()[I0]); static_for<1, kKIter, 1>{}([&](auto iKIter) { Impl{}(c_vec, - b_vec.template get_as()[iKIter], - a_vec.template get_as()[iKIter]); + reinterpret_cast(b_vec) + .template get_as()[iKIter], + reinterpret_cast(a_vec) + .template get_as()[iKIter]); }); return c_vec; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index ecc4165e09..e618d66a75 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -14,9 +14,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8 using BDataType = fp16_t; using CDataType = float; - using AVecType = array; - using BVecType = array; - using CVecType = array; + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; static constexpr index_t kM = 32; static constexpr index_t kN = 32; @@ -36,25 +36,14 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8 CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const { - c_vec.template get_as()[number<0>{}] = - __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec.template get_as()[number<0>{}], - b_vec.template get_as()[number<0>{}], - c_vec.template get_as()[number<0>{}], - 0, - 0, - 0); + c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0); } // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { return bit_cast( - __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec.template get_as()[number<0>{}], - b_vec.template get_as()[number<0>{}], - fp32x16_t{0.f}, - 0, - 0, - 0)); + __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0)); } }; @@ -64,9 +53,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16 using BDataType = fp16_t; using CDataType = float; - using AVecType = array; - using BVecType = array; - using CVecType = array; + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; static constexpr index_t kM = 16; static constexpr index_t kN = 16; @@ -86,25 +75,14 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16 CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const { - c_vec.template get_as()[number<0>{}] = - __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec.template get_as()[number<0>{}], - b_vec.template get_as()[number<0>{}], - c_vec.template get_as()[number<0>{}], - 0, - 0, - 0); + c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0); } // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { return bit_cast( - __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec.template get_as()[number<0>{}], - b_vec.template get_as()[number<0>{}], - fp32x4_t{0.f}, - 0, - 0, - 0)); + __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); } }; @@ -115,9 +93,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 using BDataType = bf16_t; using CDataType = float; - using AVecType = array; - using BVecType = array; - using CVecType = array; + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; static constexpr index_t kM = 32; static constexpr index_t kN = 32; @@ -137,25 +115,14 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const { - c_vec.template get_as()[number<0>{}] = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k( - a_vec.template get_as()[number<0>{}], - b_vec.template get_as()[number<0>{}], - c_vec.template get_as()[number<0>{}], - 0, - 0, - 0); + c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0); } // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { return bit_cast( - __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec.template get_as()[number<0>{}], - b_vec.template get_as()[number<0>{}], - fp32x16_t{0.f}, - 0, - 0, - 0)); + __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0)); } }; @@ -165,9 +132,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 using BDataType = bf16_t; using CDataType = float; - using AVecType = array; - using BVecType = array; - using CVecType = array; + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; static constexpr index_t kM = 16; static constexpr index_t kN = 16; @@ -187,25 +154,14 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const { - c_vec.template get_as()[number<0>{}] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k( - a_vec.template get_as()[number<0>{}], - b_vec.template get_as()[number<0>{}], - c_vec.template get_as()[number<0>{}], - 0, - 0, - 0); + c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0); } // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { - return bit_cast(__builtin_amdgcn_mfma_f32_16x16x16bf16_1k( - a_vec.template get_as()[number<0>{}], - b_vec.template get_as()[number<0>{}], - fp32x4_t{0.f}, - 0, - 0, - 0)); + return bit_cast( + __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); } }; @@ -217,9 +173,9 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base using BDataType = BType_; using CDataType = float; - using AVecType = array; - using BVecType = array; - using CVecType = array; + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; static constexpr index_t kM = 32; static constexpr index_t kN = 32; @@ -241,48 +197,27 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base { #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) if constexpr(std::is_same_v && std::is_same_v) - c_vec.template get_as()[number<0>{}] = - __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( - bit_cast(a_vec), - bit_cast(b_vec), - c_vec.template get_as()[number<0>{}], - 0, - 0, - 0); + c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); else if constexpr(std::is_same_v && std::is_same_v) - c_vec.template get_as()[number<0>{}] = - __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8( - bit_cast(a_vec), - bit_cast(b_vec), - c_vec.template get_as()[number<0>{}], - 0, - 0, - 0); + c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8( + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); else if constexpr(std::is_same_v && std::is_same_v) - c_vec.template get_as()[number<0>{}] = - __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8( - bit_cast(a_vec), - bit_cast(b_vec), - c_vec.template get_as()[number<0>{}], - 0, - 0, - 0); + c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8( + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); else if constexpr(std::is_same_v && std::is_same_v) - c_vec.template get_as()[number<0>{}] = - __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8( - bit_cast(a_vec), - bit_cast(b_vec), - c_vec.template get_as()[number<0>{}], - 0, - 0, - 0); + c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8( + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); #else static_for<0, 8, 1>{}([&](auto k) { - float a_f32 = type_convert(a_vec.template get_as()[number{}]); - float b_f32 = type_convert(b_vec.template get_as()[number{}]); + float a_f32 = + type_convert(reinterpret_cast&>(a_vec) + .template get_as()[number{}]); + float b_f32 = + type_convert(reinterpret_cast&>(b_vec) + .template get_as()[number{}]); - c_vec.template get_as()[number<0>{}] = __builtin_amdgcn_mfma_f32_32x32x2f32( - a_f32, b_f32, c_vec.template get_as()[number<0>{}], 0, 0, 0); + c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0); }); #endif } diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp index 02c8812d49..843d091c48 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp @@ -33,9 +33,9 @@ struct WarpGemmImpl CK_TILE_DEVICE void operator()(CWarpTensor& c, const AWarpTensor& a, const BWarpTensor& b) const { - using AVec = array; - using BVec = array; - using CVec = array; + using AVec = ext_vector_t; + using BVec = ext_vector_t; + using CVec = ext_vector_t; constexpr auto I0 = number<0>{}; @@ -53,9 +53,9 @@ struct WarpGemmImpl { CWarpTensor c; - using AVec = array; - using BVec = array; - using CVec = array; + using AVec = ext_vector_t; + using BVec = ext_vector_t; + using CVec = ext_vector_t; constexpr auto I0 = number<0>{};