From 00d05ab32ef0b0e3faab0e6d99aee2286b2b75f7 Mon Sep 17 00:00:00 2001 From: joyeamd Date: Tue, 6 Jan 2026 15:39:00 +0800 Subject: [PATCH] Merge some updates for ck_tile headers (#3342) * fix some issues from internal branch * update cshuffle_epilogue * update cshuffle_epilogue * update cshuffle * update warp_gemm [ROCm/composable_kernel commit: b78563b3d3edf1b2cd686ff0c0994ca2538419ef] --- include/ck_tile/core/arch/arch.hpp | 26 ++-- .../ck_tile/core/tensor/transpose_tile.hpp | 29 +--- .../ops/epilogue/cshuffle_epilogue.hpp | 35 +++-- .../ops/gemm/kernel/universal_gemm_kernel.hpp | 124 ++++++++++++------ ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 13 +- .../ops/gemm/pipeline/tile_gemm_traits.hpp | 8 +- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 10 ++ .../gemm/warp/warp_gemm_attribute_wmma.hpp | 17 ++- .../warp/warp_gemm_attribute_wmma_impl.hpp | 7 +- ..._gemm_attribute_wmma_impl_16bit_traits.hpp | 8 ++ ...p_gemm_attribute_wmma_impl_8bit_traits.hpp | 10 ++ ...p_gemm_attribute_wmma_impl_base_traits.hpp | 4 + .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 2 + .../gemm/test_gemm_pipeline_ut_cases.inc | 31 +++-- 14 files changed, 205 insertions(+), 119 deletions(-) diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index c5c1a6e2c6..97e962f5a3 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -1124,8 +1124,14 @@ CK_TILE_DEVICE static constexpr auto get_device_arch() { // FIXME(0): on all devices except gfx11 it returns gfx12_t // FIXME(1): during the host compilation pass it returns gfx12_t -#if defined(__gfx11__) +#if defined(__gfx103__) + return gfx103_t{}; +#elif defined(__gfx11__) return gfx11_t{}; +#elif defined(__gfx950__) + return gfx950_t{}; +#elif defined(__gfx9__) + return gfx9_t{}; #else return gfx12_t{}; #endif @@ -1146,26 +1152,10 @@ CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx950_t) { return 64; } CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx_invalid_t) { return 0; } -CK_TILE_DEVICE static constexpr auto arch_tag_dispatch() -{ -#if defined(__gfx103__) - return gfx103_t{}; -#elif defined(__gfx11__) - return gfx11_t{}; -#elif defined(__gfx12__) - return gfx12_t{}; -#elif defined(__gfx950__) - return gfx950_t{}; -#elif defined(__gfx9__) - return gfx9_t{}; -#else - return gfx_invalid_t{}; -#endif -} } // namespace detail CK_TILE_DEVICE static constexpr auto get_n_lds_banks() { - return detail::get_n_lds_banks(detail::arch_tag_dispatch()); + return detail::get_n_lds_banks(get_device_arch()); } enum LLVMSchedGroupMask : int32_t diff --git a/include/ck_tile/core/tensor/transpose_tile.hpp b/include/ck_tile/core/tensor/transpose_tile.hpp index e5a0664ec9..50927c5ca4 100644 --- a/include/ck_tile/core/tensor/transpose_tile.hpp +++ b/include/ck_tile/core/tensor/transpose_tile.hpp @@ -34,46 +34,23 @@ CK_TILE_DEVICE void transpose_tile2d_impl_in_thread(OutTensor& out_tensor, constexpr auto y_in_desc = InTensor::get_tile_distribution().get_ys_to_d_descriptor(); constexpr auto y_out_desc = OutTensor::get_tile_distribution().get_ys_to_d_descriptor(); - // y_dim_out_to_in - // For swapped Hs tile case I need only get_rh_minor_to_y - // since rh_major are already swapped due to swapped Hs. - constexpr auto get_rh_minor_to_y = [](auto dstr_tensor) { - using DstrEncode = typename decltype(dstr_tensor.get_tile_distribution())::DstrEncode; - - map rh_minor_to_y_; - - static_for<0, DstrEncode::NDimY, 1>{}([&](auto i) { - constexpr index_t rh_minor = DstrEncode::ys_to_rhs_minor_[i]; - - rh_minor_to_y_(rh_minor) = i; - }); - - return rh_minor_to_y_; - }; - // In swapped Hs case -> tile // we have same rh_major, but reversed rh_minor! - constexpr auto rh_minor_to_y_in = get_rh_minor_to_y(InTensor{}); - constexpr auto rh_minor_to_y_out = get_rh_minor_to_y(OutTensor{}); + constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y(); - // Is this really needed?? Should we have simple reverse here?? constexpr auto y_dim_out_to_in = [&] { map y_dim_out_to_in_; - for(const auto& [rh_minor, y_out] : rh_minor_to_y_out) - { - y_dim_out_to_in_(y_out) = rh_minor_to_y_in[rh_minor]; - } + static_for<0, NDimY, 1>{}([&](auto i) { y_dim_out_to_in_(i) = NDimY - 1 - i; }); return y_dim_out_to_in_; }(); - constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y(); constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths()); // input and output vector dim in the order of input Y dims constexpr index_t y_dim_vec_in = NDimY - 1; - constexpr index_t y_dim_vec_out = y_dim_out_to_in[NDimY - 1]; + constexpr index_t y_dim_vec_out = 0; // vector lengths constexpr index_t vec_length_in = y_lengths[y_dim_vec_in]; diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index c73897f064..97f936fde9 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -333,14 +333,30 @@ struct CShuffleEpilogue { constexpr int RakedXDLN_PerWarp = NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp; // BlockedLayout - return tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple>, - tuple>, - sequence<1, 2, 2>, - sequence<0, 0, 2>>{}; + // this branch is for original a16w4 + if constexpr(is_any_of::value || + is_any_of::value) + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 2, 2>, + sequence<0, 0, 2>>{}; + } + else + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 2, 2>, + sequence<0, 0, 1>>{}; + } } }(); constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding( @@ -351,7 +367,8 @@ struct CShuffleEpilogue CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { - return MPerIterationShuffle * NPerIterationShuffle * sizeof(ODataType); + constexpr auto lds_block_desc = MakeLdsBlockDescriptor(); + return lds_block_desc.get_element_space_size() * sizeof(ODataType); } template diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index c77459b4ec..628f5f7dc8 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -423,7 +423,7 @@ struct UniversalGemmKernel const auto vectorSizeA = is_wave32() ? GemmPipeline::template GetVectorSizeA() : GemmPipeline::template GetVectorSizeA(); - bool AsTesnorIsValid = {true}; + bool AsTensorIsValid = {true}; static_for<0, NumATensor, 1>{}([&](auto index) { using AiLayout = remove_cvref_t>; if constexpr(std::is_same_v) @@ -437,15 +437,27 @@ struct UniversalGemmKernel "Can't support K that is not a multiple of k_batch * KPerBlock " "without padding!"); } - AsTesnorIsValid = false; + AsTensorIsValid = false; } if(kargs.K % vectorSizeA != 0) { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + const auto remainder = kargs.K % vectorSizeA; + constexpr ck_tile::index_t APackedSize = + ck_tile::numeric_traits::PackedSize; + const auto remainder_in_bytes = remainder * sizeof(ADataType) / APackedSize; + // oob can support to dword level + if(remainder_in_bytes % 4 == 0) { - CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!"); + AsTensorIsValid = true; + } + else + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!"); + } + AsTensorIsValid = false; } - AsTesnorIsValid = false; } } else @@ -457,20 +469,33 @@ struct UniversalGemmKernel CK_TILE_ERROR( "Can't support M that is not a multiple of MPerBlock without padding!"); } - AsTesnorIsValid = false; + AsTensorIsValid = false; } if(kargs.M % vectorSizeA != 0) { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + const auto remainder = kargs.M % vectorSizeA; + constexpr ck_tile::index_t APackedSize = + ck_tile::numeric_traits::PackedSize; + const auto remainder_in_bytes = remainder * sizeof(ADataType) / APackedSize; + // oob can support to dword level + if(remainder_in_bytes % 4 == 0) { - CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!"); + + AsTensorIsValid = true; + } + else + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!"); + } + AsTensorIsValid = false; } - AsTesnorIsValid = false; } } }); - bool BsTesnorIsValid = {true}; + bool BsTensorIsValid = {true}; const auto vectorSizeB = is_wave32() ? GemmPipeline::template GetVectorSizeB() : GemmPipeline::template GetVectorSizeB(); static_for<0, NumBTensor, 1>{}([&](auto index) { @@ -484,47 +509,72 @@ struct UniversalGemmKernel CK_TILE_ERROR( "Can't support N that is not a multiple of NPerBlock without padding!"); } - BsTesnorIsValid = false; + BsTensorIsValid = false; } if(kargs.N % vectorSizeB != 0) { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + const auto remainder = kargs.N % vectorSizeB; + constexpr ck_tile::index_t BPackedSize = + ck_tile::numeric_traits::PackedSize; + const auto remainder_in_bytes = remainder * sizeof(BDataType) / BPackedSize; + // oob can support to dword level + if(remainder_in_bytes % 4 == 0) { - CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!"); + BsTensorIsValid = true; + } + else + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!"); + } + BsTensorIsValid = false; } - BsTesnorIsValid = false; } - } - else - { - if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 && - GemmPipeline::kPadK == false) + else { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 && + GemmPipeline::kPadK == false) { - CK_TILE_ERROR( - "Can't support K that is not a multiple of k_batch * KPerBlock " - "without padding!"); + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support K that is not a multiple of k_batch * KPerBlock " + "without padding!"); + } + BsTensorIsValid = false; } - BsTesnorIsValid = false; - } - if(kargs.K % vectorSizeB != 0) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + if(kargs.K % vectorSizeB != 0) { - CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!"); + const auto remainder = kargs.K % vectorSizeB; + constexpr ck_tile::index_t BPackedSize = + ck_tile::numeric_traits::PackedSize; + const auto remainder_in_bytes = remainder * sizeof(BDataType) / BPackedSize; + // oob can support to dword level + if(remainder_in_bytes % 4 == 0) + { + BsTensorIsValid = true; + } + else + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "K is not a multiple of vector load size for B tensor!"); + } + BsTensorIsValid = false; + } } - BsTesnorIsValid = false; } } }); - bool DTesnorIsValid = {true}; + bool DTensorIsValid = {true}; static_for<0, NumDTensor, 1>{}([&](auto index) { using DiLayout = remove_cvref_t>; if(std::is_same_v == false) { - DTesnorIsValid = false; + DTensorIsValid = false; } if constexpr(std::is_same_v) { @@ -535,7 +585,7 @@ struct UniversalGemmKernel CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of " "NPerBlock without padding!"); } - DTesnorIsValid = false; + DTensorIsValid = false; } if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0) { @@ -543,7 +593,7 @@ struct UniversalGemmKernel { CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!"); } - DTesnorIsValid = false; + DTensorIsValid = false; } } else @@ -555,7 +605,7 @@ struct UniversalGemmKernel CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of " "MPerBlock without padding!"); } - DTesnorIsValid = false; + DTensorIsValid = false; } if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0) { @@ -563,7 +613,7 @@ struct UniversalGemmKernel { CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!"); } - DTesnorIsValid = false; + DTensorIsValid = false; } } }); @@ -608,7 +658,7 @@ struct UniversalGemmKernel return false; } } - return AsTesnorIsValid && BsTesnorIsValid && DTesnorIsValid; + return AsTensorIsValid && BsTensorIsValid && DTensorIsValid; } CK_TILE_DEVICE static auto diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index d68da14ac5..6199142d98 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -845,10 +845,10 @@ struct UniversalGemmBasePolicy template CK_TILE_DEVICE static constexpr index_t GetSmemSizeA() { - constexpr index_t smem_size_a = - integer_least_multiple(sizeof(typename Problem::ADataType) * - Problem::BlockGemmShape::kM * Problem::BlockGemmShape::kK, - 16); + using ADataType = remove_cvref_t; + constexpr auto a_lds_block_desc = Derived::template MakeALdsBlockDescriptor(); + constexpr index_t smem_size_a = integer_least_multiple( + a_lds_block_desc.get_element_space_size() * sizeof(ADataType), 16); return smem_size_a; } @@ -859,8 +859,9 @@ struct UniversalGemmBasePolicy std::conditional_t, typename Problem::ADataType, typename Problem::BDataType>; - constexpr index_t smem_size_b = integer_least_multiple( - sizeof(BDataType) * Problem::BlockGemmShape::kN * Problem::BlockGemmShape::kK, 16); + constexpr auto b_lds_block_desc = Derived::template MakeBLdsBlockDescriptor(); + constexpr index_t smem_size_b = integer_least_multiple( + b_lds_block_desc.get_element_space_size() * sizeof(BDataType), 16); return smem_size_b; } diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp index 47607a40f5..5b00eb244b 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp @@ -53,11 +53,11 @@ struct TileGemmUniversalTraits static constexpr int _VectorSize = VectorSize_; static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_; - using AsLayout = AsLayout_; - using BsLayout = BsLayout_; - using CLayout = CLayout_; + using AsLayout = AsLayout_; + using BsLayout = BsLayout_; + using CLayout = CLayout_; + static constexpr bool TransposeC = TransposeC_; - static constexpr bool TransposeC = TransposeC_; static constexpr bool UseStructuredSparsity = UseStructuredSparsity_; static constexpr bool UsePersistentKernel = UsePersistentKernel_; static constexpr index_t NumWaveGroups = NumWaveGroups_; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index c0fbf8e5d3..7bcc9107da 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -306,6 +306,16 @@ using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl, 2>>; +using WarpGemmMfma_f32_16x16x64_fp8_fp8_CTransposed = + WarpGemmImpl, + 2>>; + +using WarpGemmMfma_f32_16x16x64_bf8_bf8_CTransposed = + WarpGemmImpl, + 2>>; + template using WarpGemmMfma_f32_16x16x128_f8f6f4 = WarpGemmImpl< WarpGemmAttributeMfma, AttrNumAccess>>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp index ff2ba501fe..ef31d06c9c 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp @@ -68,6 +68,19 @@ struct WarpGemmAttributeWmma { using Impl = remove_cvref_t; + // When kTransC is true and A/B types differ, we need an impl with swapped types + using TransposedImpl = + std::conditional_t, + WarpGemmAttributeWmmaImpl>, + Impl>; + using ADataType = typename Impl::ADataType; using BDataType = typename Impl::BDataType; using CDataType = typename Impl::CDataType; @@ -104,7 +117,7 @@ struct WarpGemmAttributeWmma { if constexpr(kTransC) { - Impl{}(c_vec, b_vec, a_vec, bool_constant{}); + TransposedImpl{}(c_vec, b_vec, a_vec, bool_constant{}); } else { @@ -117,7 +130,7 @@ struct WarpGemmAttributeWmma { if constexpr(kTransC) { - return Impl{}(b_vec, a_vec); + return TransposedImpl{}(b_vec, a_vec); } else { diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp index 0464ffbce4..cf0efbbaae 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp @@ -22,9 +22,10 @@ struct WmmaTraits; template struct WarpGemmAttributeWmmaImpl { - using ADataType = typename Traits::ADataType; - using BDataType = typename Traits::BDataType; - using CDataType = typename Traits::CDataType; + using TraitsType = Traits; + using ADataType = typename Traits::ADataType; + using BDataType = typename Traits::BDataType; + using CDataType = typename Traits::CDataType; using AVecType = typename Traits::AVecType; using BVecType = typename Traits::BVecType; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_16bit_traits.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_16bit_traits.hpp index 992f0a8783..d9d4ec9430 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_16bit_traits.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_16bit_traits.hpp @@ -10,6 +10,8 @@ template <> struct WmmaTraits : WmmaTraitsBase { + using ArchType = gfx11_t; + template CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) @@ -30,6 +32,8 @@ template <> struct WmmaTraits : WmmaTraitsBase { + using ArchType = gfx11_t; + template CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) @@ -50,6 +54,8 @@ template <> struct WmmaTraits : WmmaTraitsBase { + using ArchType = gfx12_t; + template CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) @@ -70,6 +76,8 @@ template <> struct WmmaTraits : WmmaTraitsBase { + using ArchType = gfx12_t; + template CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp index 34c4dbe551..eace7e3956 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp @@ -10,6 +10,8 @@ template <> struct WmmaTraits : WmmaTraitsBase { + using ArchType = gfx11_t; + template CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) @@ -35,6 +37,8 @@ template <> struct WmmaTraits : WmmaTraitsBase { + using ArchType = gfx12_t; + template CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) @@ -60,6 +64,8 @@ template <> struct WmmaTraits : WmmaTraitsBase { + using ArchType = gfx12_t; + template CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) @@ -80,6 +86,8 @@ template <> struct WmmaTraits : WmmaTraitsBase { + using ArchType = gfx12_t; + template CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) @@ -100,6 +108,8 @@ template <> struct WmmaTraits : WmmaTraitsBase { + using ArchType = gfx12_t; + template CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp index 524215ddfa..e00b9d772f 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp @@ -10,6 +10,8 @@ struct WmmaTraitsBase; template struct WmmaTraitsBase { + using ArchType = gfx11_t; + using ADataType = ADType; using BDataType = BDType; using CDataType = CDType; @@ -57,6 +59,8 @@ struct WmmaTraitsBase template struct WmmaTraitsBase { + using ArchType = gfx12_t; + using ADataType = ADType; using BDataType = BDType; using CDataType = CDType; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 82c6e43834..d6c21e88b5 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -100,6 +100,7 @@ template<> struct Dispatcher { using Ty template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8_CTransposed; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8_CTransposed; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; }; @@ -113,6 +114,7 @@ template<> struct Dispatcher { using Ty template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8_CTransposed; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8_CTransposed; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; }; // scale mfma based f8f6f4 diff --git a/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc b/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc index 6e7c086e55..5239b2d888 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc @@ -31,7 +31,14 @@ TYPED_TEST(TEST_SUITE_NAME, SmallM) if constexpr(std::is_same_v) { - EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + if(M * sizeof(typename TestFixture::ADataType) % 4 == 0) // oob fit dword + { + this->Run(M, N, K); + } + else + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } } else { @@ -84,7 +91,14 @@ TYPED_TEST(TEST_SUITE_NAME, MidLargeM) } else { - EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + if(M * sizeof(typename TestFixture::ADataType) % 4 == 0) // oob fit dword + { + this->Run(M, N, K); + } + else + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } } } else @@ -103,18 +117,7 @@ TYPED_TEST(TEST_SUITE_NAME, PaddK) for(int M : Ms) { - if constexpr(std::is_same_v) - { -#if defined(ARCH_GFX12) || defined(ARCH_GFX11) - this->Run(M, N, K); -#else - EXPECT_THROW(this->Run(M, N, K), std::runtime_error); -#endif - } - else - { - this->Run(M, N, K); - } + this->Run(M, N, K); } }