diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 0723026836..c540ce2bd4 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -234,4 +234,44 @@ CK_TILE_HOST_DEVICE constexpr index_t get_smem_capacity() #endif } +// ============================================================================ +// Architecture-specific parameter definitions +// We'll define all parameters for all supported architectures here. +// ============================================================================ + +// Parameters for gfx120x (using a namespace for organization or just global constexpr) +namespace Gfx120x { +constexpr ck_tile::index_t WarpTile = 32; +constexpr ck_tile::index_t VecLen = 8; +} + +// Parameters for gfx90x (example values, adjust as needed) +namespace Gfx90x { +constexpr ck_tile::index_t WarpTile = 64; +constexpr ck_tile::index_t VecLen = 4; +} + +// Generic Parameters - should never be used in this example +// templated run function should only be instantiated for Gfx120x and Gfx90x +namespace Generic { +constexpr ck_tile::index_t WarpTile = -1; +} + +// Helper to get VecLen based on GfxId +template +struct GfxConfig +{ + static constexpr int get_vec_len() + { + if constexpr (GfxId == 1200) + { + return Gfx120x::VecLen; + } + else if constexpr (GfxId == 900) + { + return Gfx90x::VecLen; + } + } +}; + } // namespace ck_tile diff --git a/include/ck_tile/core/arch/warp_tile_size.hpp b/include/ck_tile/core/arch/warp_tile_size.hpp deleted file mode 100644 index c3751da219..0000000000 --- a/include/ck_tile/core/arch/warp_tile_size.hpp +++ /dev/null @@ -1,20 +0,0 @@ -// ============================================================================ -// Architecture-specific parameter definitions -// We'll define all parameters for all supported architectures here. -// ============================================================================ - -// Parameters for gfx120x (using a namespace for organization or just global constexpr) -namespace Gfx120x { -constexpr ck_tile::index_t WarpTile = 32; -} - -// Parameters for gfx90x (example values, adjust as needed) -namespace Gfx90x { -constexpr ck_tile::index_t WarpTile = 64; -} - -// Generic Parameters - should never be used in this example -// templated run function should only be instantiated for Gfx120x and Gfx90x -namespace Generic { -constexpr ck_tile::index_t WarpTile = -1; -} diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index c201293389..82695697a6 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -51,8 +51,12 @@ #include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp" +#if 0 #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp" +#endif +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_generic.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_generic_impl.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 272426a0e6..6627bee87c 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -5,16 +5,22 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp" + +#if 0 #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp" - #include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp" +#endif + +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_generic.hpp" + namespace ck_tile { // fp16 +#if 0 using WarpGemmMfmaF16F16F32M32N32K8 = WarpGemmImpl< WarpGemmAttributeMfma>>; @@ -23,7 +29,15 @@ using WarpGemmMfmaF16F16F32M16N16K16 = WarpGemmImpl< using WarpGemmWmmaF16F16F32M16N16K16 = WarpGemmImpl< WarpGemmAttributeWmma>>; +#endif +using WarpGemmMfmaF16F16F32M16N16K16 = WarpGemmImpl< + WarpGemmAttributeGeneric>>; + +using WarpGemmWmmaF16F16F32M16N16K16 = WarpGemmImpl< + WarpGemmAttributeGeneric>>; + +#if 0 #if defined(__gfx950__) template using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl< @@ -343,4 +357,6 @@ using WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed = WarpGemmImpl>>; +#endif + } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_generic.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_generic.hpp new file mode 100644 index 0000000000..15d08b3084 --- /dev/null +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_generic.hpp @@ -0,0 +1,452 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_generic_impl_F16F16F32M16N16K16.hpp" + +namespace ck_tile { + +// Number of groups of consecutive elements to fill in a ABKLane +enum class WGAttrNumAccessEnum +{ + Single = 1, + Double = 2, + Quad = 4, + Invalid = -1 +}; + +template +struct WarpGemmAttributeGeneric +{ + using Impl = remove_cvref_t; + static constexpr auto AttrNumAccess = AttrNumAccess_; + static constexpr auto AttrNumAccessV = static_cast(AttrNumAccess); + + using ADataType = typename Impl::ADataType; + using BDataType = typename Impl::BDataType; + using CDataType = typename Impl::CDataType; + + using AVecType = typename Impl::AVecType; + using BVecType = typename Impl::BVecType; + using CVecType = typename Impl::CVecType; + + static constexpr index_t kM = Impl::kM; + static constexpr index_t kN = Impl::kN; + static constexpr index_t kK = Impl::kK; + static constexpr index_t kKPerThread = Impl::kABKPerLane; + static constexpr index_t kCMLane = Impl::kCMLane; + + CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } + + static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1, + "Multi-block WarpGemmAttributeGenericImpl is not supported"); + + template + static constexpr auto get_warp_dstr_encoding() + { + if constexpr(AttrNumAccessV == 1) + { + return tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + else + { + static_assert(kKPerThread % AttrNumAccessV == 0, + "kKPerThread must be divisible by NumAccess"); + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>{}; + } + } + using AWarpDstrEncoding = decltype(get_warp_dstr_encoding()); + using BWarpDstrEncoding = decltype(get_warp_dstr_encoding()); + + using CWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 1>, + sequence<0, 2>>; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + Impl{}(c_vec, a_vec, b_vec, bool_constant{}); + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { + return Impl{}(a_vec, b_vec); + } +}; + +template +struct WarpGemmAttributeGenericIterateK +{ + static_assert(kKIter > 0, "wrong!"); + + using Impl = remove_cvref_t; + static constexpr auto AttrNumAccess = AttrNumAccess_; + static constexpr auto AttrNumAccessV = static_cast(AttrNumAccess); + + using ADataType = typename Impl::ADataType; + using BDataType = typename Impl::BDataType; + using CDataType = typename Impl::CDataType; + + using AVecType = + ext_vector_t::vector_size * kKIter>; + using BVecType = + ext_vector_t::vector_size * kKIter>; + using CVecType = typename Impl::CVecType; + + static constexpr index_t kM = Impl::kM; + static constexpr index_t kN = Impl::kN; + static constexpr index_t kK = Impl::kK * kKIter; + static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter; + + CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; } + + static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1, + "Multi-block on both M & N directions is not supported"); + + CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding() + { + if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) + { + if constexpr(AttrNumAccessV == 1) + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + else + { + static_assert(kKPerThread % AttrNumAccessV == 0, + "kKPerThread must be divisible by NumAccess"); + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>{}; + } + } + else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock) + { + static_assert(AttrNumAccessV == 1, + "Multiple access is not supported when using multi-block"); + // each M blocks share the same data + return tile_distribution_encoding< + sequence, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1) + { + static_assert(AttrNumAccessV == 1, + "Multiple access is not supported when using multi-block"); + // single block to multi-block thread mapping + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + } + + CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding() + { + if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) + { + if constexpr(AttrNumAccessV == 1) + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + else + { + + static_assert(kKPerThread % AttrNumAccessV == 0, + "kKPerThread must be divisible by NumAccess"); + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>{}; + } + } + else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock) + { + static_assert(AttrNumAccessV == 1, + "Multiple access is not supported when using multi-block"); + // single block to multi-block thread mapping + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1) + { + static_assert(AttrNumAccessV == 1, + "Multiple access is not supported when using multi-block"); + // each N blocks share the same data + return tile_distribution_encoding< + sequence, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + } + + CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding() + { + if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 1>, + sequence<0, 2>>{}; + } + else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock) + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 1>, + sequence<0, 2>>{}; + } + else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1) + { + return tile_distribution_encoding< + sequence<>, + tuple< + sequence, + sequence>, + tuple>, + tuple>, + sequence<1, 1>, + sequence<0, 2>>{}; + } + } + + using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding()); + + using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding()); + + using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding()); + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + using buf_a = thread_buffer; + using buf_b = thread_buffer; + + static_for<0, kKIter, 1>{}([&](auto iKIter) { + Impl{}(c_vec, + reinterpret_cast(a_vec) + .template get_as()[iKIter], + reinterpret_cast(b_vec) + .template get_as()[iKIter], + bool_constant{}); + }); + } + + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + number, + bool_constant = {}) const + { + using buf_a = thread_buffer; + using buf_b = thread_buffer; + + static_assert(iKIter < kKIter); + + // static_for<0, kKIter, 1>{}([&](auto iKIter) { + Impl{}(c_vec, + reinterpret_cast(a_vec) + .template get_as()[iKIter], + reinterpret_cast(b_vec) + .template get_as()[iKIter], + bool_constant{}); + //}); + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { + constexpr auto I0 = number<0>{}; + using buf_a = thread_buffer; + using buf_b = thread_buffer; + + // c = a * b + auto c_vec = Impl{}( + reinterpret_cast(a_vec).template get_as()[I0], + reinterpret_cast(b_vec).template get_as()[I0]); + + // c += a * b + static_for<1, kKIter, 1>{}([&](auto iKIter) { + Impl{}(c_vec, + reinterpret_cast(a_vec) + .template get_as()[iKIter], + reinterpret_cast(b_vec) + .template get_as()[iKIter]); + }); + + return c_vec; + } +}; + +template +struct WarpGemmAttributeGenericTransposedCDistribution +{ + using Impl = remove_cvref_t; + static constexpr auto AttrNumAccess = AttrNumAccess_; + static constexpr auto AttrNumAccessV = static_cast(AttrNumAccess); + + using ADataType = typename Impl::BDataType; + using BDataType = typename Impl::ADataType; + using CDataType = typename Impl::CDataType; + + using AVecType = typename Impl::BVecType; + using BVecType = typename Impl::AVecType; + using CVecType = typename Impl::CVecType; + + static constexpr index_t kM = Impl::kN; + static constexpr index_t kN = Impl::kM; + static constexpr index_t kK = Impl::kK; + static constexpr index_t kKPerThread = Impl::kABKPerLane; + + CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } + + static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1, + "Multi-block WarpGemmAttributeGenericImpl is not supported"); + + template + static constexpr auto get_warp_dstr_encoding() + { + if constexpr(AttrNumAccessV == 1) + { + return tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + else + { + static_assert(kKPerThread % AttrNumAccessV == 0, + "kKPerThread must be divisible by NumAccess"); + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>{}; + } + } + using AWarpDstrEncoding = decltype(get_warp_dstr_encoding()); + using BWarpDstrEncoding = decltype(get_warp_dstr_encoding()); + + using CWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + // swap A and B + Impl{}(c_vec, b_vec, a_vec, bool_constant{}); + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { + // swap A and B + return Impl{}(b_vec, a_vec); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_generic_impl_F16F16F32M16N16K16.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_generic_impl_F16F16F32M16N16K16.hpp new file mode 100644 index 0000000000..e7620260a9 --- /dev/null +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_generic_impl_F16F16F32M16N16K16.hpp @@ -0,0 +1,132 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// TODO: refactor warp-gemm +// currently there is a discrepency for vav/vva if we need transpose C/D +// e.g. if we want A:agpr, B:vgpr, we have to use vva in WGAttrEnum +// because we swap the A/B pointer in _impl code (but not known this info here) +enum class WGAttrCtlEnum +{ + Default_ = 0, + Raw_vvv = 1, // c-vgpr, a-vgpr, b-vgpr + Raw_vaa = 2, // c-vgpr, a-agpr, b-agpr + Raw_vav = 3, // c-vgpr, a-agpr, b-vgpr + Raw_vva = 4, // c-vgpr, a-vgpr, b-agpr + Raw_avv = 5, // c-agpr, a-vgpr, b-vgpr + // raw_a_a_a = 3, // c-agpr, a-agpr, b-agpr +}; + +#define DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_) \ + if constexpr(post_nop_) \ + { \ + asm volatile(mfma_ " %0, %1, %2, %3 ; yyy\n" \ + "s_nop 3" \ + : dmod_(c_vec) \ + : amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \ + :); \ + } \ + else \ + { \ + asm volatile(mfma_ " %0, %1, %2, %3\n" \ + : dmod_(c_vec) \ + : amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \ + :); \ + } + +#define DISPATCH_MFMA_CTRL_(mfma_, ctrl_) \ + if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vvv) \ + { \ + DISPATCH_MFMA_(mfma_, "+v", "v", "v", "v") \ + } \ + else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vaa) \ + { \ + DISPATCH_MFMA_(mfma_, "+v", "a", "a", "v") \ + } \ + else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vav) \ + { \ + DISPATCH_MFMA_(mfma_, "+v", "a", "v", "v") \ + } \ + else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vva) \ + { \ + DISPATCH_MFMA_(mfma_, "+v", "v", "a", "v") \ + } \ + else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_avv) \ + { \ + DISPATCH_MFMA_(mfma_, "+a", "v", "v", "a") \ + } + + +template +struct WarpGemmAttributeGenericImplF16F16F32M16N16K16 +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = fp16_t; + using BDataType = fp16_t; + using CDataType = float; + + static constexpr int VecLen = GfxConfig::get_vec_len(); + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 16; + static constexpr index_t kN = 16; + static constexpr index_t kK = 16; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 16; + static constexpr index_t kBNLane = 16; + static constexpr index_t kABKLane = 4; + static constexpr index_t kABKPerLane = 4; + + static constexpr index_t kCMLane = 4; + static constexpr index_t kCNLane = 16; + static constexpr index_t kCM0PerLane = 1; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16f16", Ctrl) + else + { +#if defined(__gfx9__) + c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0); +#elif defined(__gfx12__) + c_vec = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_vec, b_vec, c_vec); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; +#endif + } + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { +#if defined(__gfx9__) + return bit_cast( + __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); +#elif defined(__gfx12__) + return bit_cast( + __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_vec, b_vec, fp32x8_t{0.f})); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + return CVecType{0.f}; +#endif + } +}; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_1.hpp similarity index 99% rename from include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp rename to include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_1.hpp index f000092b72..1614d6fb87 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_1.hpp @@ -334,8 +334,8 @@ struct WarpGemmAttributeWmmaTransposedCDistribution sequence>, tuple>, tuple>, - sequence<2, 2>, - sequence<0, 2>>; + sequence<2, 4>, + sequence<0, 4>>; // c_vec += a_vec * b_vec template diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_1.hpp similarity index 100% rename from include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp rename to include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_1.hpp